diff --git a/.github/workflows/build_cmake.yml b/.github/workflows/build_cmake.yml deleted file mode 100644 index 5e74952eab..0000000000 --- a/.github/workflows/build_cmake.yml +++ /dev/null @@ -1,79 +0,0 @@ -name: CMake Build Matrix - -on: [push, pull_request] - -env: - CMAKE_VERSION: 3.28.0 # Oldest supported - NINJA_VERSION: 1.12.1 # Latest - BUILD_TYPE: Release - CCACHE_VERSION: 4.10.2 # Latest - NINJA_STATUS: "[%f/%t %o/sec] " - -jobs: - build: - name: ${{ matrix.config.name }} - runs-on: ${{ matrix.config.os }} - strategy: - fail-fast: false - matrix: - config: - - { - name: "Windows MSVC 2022", artifact: "Windows-MSVC-2022.7z", - os: windows-2022, - cc: "cl", cxx: "cl", - environment_script: "C:/Program Files/Microsoft Visual Studio/2022/Enterprise/VC/Auxiliary/Build/vcvars64.bat", - gpu_var: "Off", - openmp_var: "On" - } - - { - name: "macOS Latest Clang", artifact: "macOS.7z", - # In a release branch, we should fix this for the lifetime - # of the branch. - os: macos-latest, - cc: "clang", cxx: "clang++", - gpu_var: "Off", - openmp_var: "Off" - } - - { - name: "macOS Latest GCC 14 with OpenCL", artifact: "macOS-gcc-OpenCL.7z", - # In a release branch, we should fix this for the lifetime - # of the branch. - os: macos-latest, - cc: "gcc-14", cxx: "g++-14", - gpu_var: "OpenCL", - openmp_var: "On" - } - - env: - CC: ${{ matrix.config.cc }} - CXX: ${{ matrix.config.cxx }} - CI_JOB_ID: ${{ github.sha }} # Tell CMake it's running in CI - OPENMP_VAR: ${{ matrix.config.openmp_var }} - GPU_VAR: ${{ matrix.config.gpu_var }} - ENVIRONMENT_SCRIPT: ${{ matrix.config.environment_script }} - - steps: - - uses: actions/checkout@v4 - with: - show-progress: false - - - name: Download Ninja, CMake, and CCache - run: cmake -P .github/scripts/download-ninja-cmake.cmake - - - name: ccache cache files - uses: actions/cache@v4 - with: - path: .ccache - key: ${{ matrix.config.name }}-ccache-${{ github.sha }} - restore-keys: | - ${{ matrix.config.name }}-ccache- - - - name: Configure - run: cmake -P .github/scripts/configure.cmake - - - name: Build - run: cmake -P .github/scripts/build.cmake - - - name: Run tests - run: cmake -P .github/scripts/test.cmake - diff --git a/.github/workflows/metatomic-ci.yml b/.github/workflows/metatomic-ci.yml new file mode 100644 index 0000000000..2cc7ad5f0f --- /dev/null +++ b/.github/workflows/metatomic-ci.yml @@ -0,0 +1,79 @@ +name: Metatomic Integration Tests + +on: + workflow_dispatch: + inputs: + pixi_envs_branch: + description: "pixi_envs branch for test definitions" + required: false + default: "main" + push: + branches: [metatomic] + pull_request: + +jobs: + test: + runs-on: ubuntu-latest + strategy: + fail-fast: false + matrix: + include: + - mpi: thread-mpi + build_task: gromk-tmpi + gmx_bin: gmx + - mpi: real-mpi + build_task: gromk-mpi + gmx_bin: gmx_mpi + + name: metatomic (${{ matrix.mpi }}) + + steps: + - name: Checkout GROMACS + uses: actions/checkout@v4 + with: + path: gromacs + + - name: Checkout pixi_envs + uses: actions/checkout@v4 + with: + repository: HaoZeke/pixi_envs + ref: ${{ github.event.inputs.pixi_envs_branch || 'main' }} + path: pixi_envs + + - name: Place GROMACS inside pixi_envs layout + run: ln -s ${{ github.workspace }}/gromacs pixi_envs/orgs/metatensor/gromacs/gromacs + + - name: Install pixi + uses: prefix-dev/setup-pixi@v0.8.1 + with: + manifest-path: pixi_envs/orgs/metatensor/gromacs/pixi.toml + environments: metatomic-cpu + + - name: Build GROMACS (${{ matrix.mpi }}) + working-directory: pixi_envs/orgs/metatensor/gromacs + run: pixi run -e metatomic-cpu ${{ matrix.build_task }} Release + + - name: Show cmake error log on failure + if: failure() + working-directory: pixi_envs/orgs/metatensor/gromacs + run: | + for f in gromacs/build-*/CMakeFiles/CMakeError.log; do + echo "=== $f ===" + cat "$f" 2>/dev/null || true + done + for f in gromacs/build-*/CMakeFiles/CMakeOutput.log; do + echo "=== $f (last 100 lines) ===" + tail -100 "$f" 2>/dev/null || true + done + + - name: Generate test model + working-directory: pixi_envs/orgs/metatensor/gromacs + run: cd mta_test && pixi run -e metatomic-cpu python create_model.py + + - name: Run tests (skip dd8/dd12) + working-directory: pixi_envs/orgs/metatensor/gromacs + env: + GMX_BIN: ${{ matrix.gmx_bin }} + run: | + pixi run -e metatomic-cpu pytest mta_test/ -v \ + -m "not dd8 and not dd12" diff --git a/cmake/gmxManageMetatomic.cmake b/cmake/gmxManageMetatomic.cmake index e4a60c663f..03d3746a42 100644 --- a/cmake/gmxManageMetatomic.cmake +++ b/cmake/gmxManageMetatomic.cmake @@ -90,6 +90,25 @@ if(NOT GMX_METATOMIC STREQUAL "OFF") endif() set(GMX_TORCH ON) + + # Ensure the torch library directory is in RPATH so that libtorch.so, + # libc10.so, etc. can be found at runtime. CMAKE_INSTALL_RPATH_USE_LINK_PATH + # doesn't always extract paths from imported targets, so we add it explicitly. + # Guard prevents duplicate additions if gmxManageNNPot already added it. + if(TORCH_INSTALL_PREFIX AND NOT _torch_rpath_added) + list(APPEND CMAKE_INSTALL_RPATH "${TORCH_INSTALL_PREFIX}/lib") + set(_torch_rpath_added TRUE) + + # Use RPATH instead of RUNPATH. Modern linkers default to RUNPATH + # (--enable-new-dtags), but RUNPATH doesn't propagate to transitive + # dependencies: gmx -> libgromacs.so -> libtorch.so -> libc10.so. + include(CheckLinkerFlag) + check_linker_flag(CXX "-Wl,--disable-new-dtags" _linker_supports_disable_new_dtags) + if(_linker_supports_disable_new_dtags) + add_link_options("-Wl,--disable-new-dtags") + endif() + endif() + elseif(GMX_METATOMIC STREQUAL "TORCH") message(FATAL_ERROR "Torch not found. Please install libtorch and add its installation prefix" " to CMAKE_PREFIX_PATH or set Torch_DIR to a directory containing " diff --git a/cmake/gmxManageNNPot.cmake b/cmake/gmxManageNNPot.cmake index ecaef22d6d..560d67d5bd 100644 --- a/cmake/gmxManageNNPot.cmake +++ b/cmake/gmxManageNNPot.cmake @@ -90,6 +90,22 @@ if(NOT GMX_NNPOT STREQUAL "OFF") endif() set(GMX_TORCH ON) + + # Ensure the torch library directory is in RPATH so that libtorch.so, + # libc10.so, etc. can be found at runtime. + if(TORCH_INSTALL_PREFIX AND NOT _torch_rpath_added) + list(APPEND CMAKE_INSTALL_RPATH "${TORCH_INSTALL_PREFIX}/lib") + set(_torch_rpath_added TRUE) + + # Use RPATH instead of RUNPATH. RUNPATH doesn't propagate to + # transitive dependencies (gmx -> libgromacs -> libtorch -> libc10). + include(CheckLinkerFlag) + check_linker_flag(CXX "-Wl,--disable-new-dtags" _linker_supports_disable_new_dtags) + if(_linker_supports_disable_new_dtags) + add_link_options("-Wl,--disable-new-dtags") + endif() + endif() + elseif(GMX_NNPOT STREQUAL "TORCH") message(FATAL_ERROR "Torch not found. Please install libtorch and add its installation prefix" " to CMAKE_PREFIX_PATH or set Torch_DIR to a directory containing " diff --git a/src/gromacs/applied_forces/metatomic/metatomic_forceprovider.cpp b/src/gromacs/applied_forces/metatomic/metatomic_forceprovider.cpp index 3cbeb9e964..f7aefacead 100644 --- a/src/gromacs/applied_forces/metatomic/metatomic_forceprovider.cpp +++ b/src/gromacs/applied_forces/metatomic/metatomic_forceprovider.cpp @@ -33,7 +33,24 @@ */ /*! \internal \file * \brief - * Implements the Metatomic Force Provider class with proper domain decomposition support. + * Implements the Metatomic Force Provider class with per-rank model evaluation. + * + * Uses the GROMACS plain pairlist (excludedPairlist) as the neighbor list + * source. MTA-MTA nonbonded pairs are excluded from the classical force + * calculation via intermolecularExclusionGroup, causing them to appear in + * excludedPairlist_ instead of pairlist_. Each pair is assigned to exactly + * one rank by the GROMACS nbnxm pairlist builder. + * + * Key design points: + * - Neighbor list: built from excludedPairlist_, not AnalysisNeighborhood. + * Cell shifts are negated (GROMACS shifts first atom, metatensor shifts + * second atom). Models requesting full_list get both (i,j) and (j,i). + * - Energy: selected_atoms = nullopt (sum all per-atom energies, home + + * halo). No double counting because each pair is on one rank. + * - Forces: all-reduce on a global buffer because ForceWithVirial is not + * communicated by dd_move_f. Only home atom forces are applied. + * - Ghost deduplication: periodic ghost images share the same model index + * but all GROMACS local indices are mapped via gmxLocalToMtaIdx_. * * \author Metatensor developers * \ingroup module_applied_forces @@ -43,6 +60,11 @@ #include "metatomic_forceprovider.h" #include +#include + +#include +#include +#include #include "gromacs/domdec/localatomset.h" #include "gromacs/mdlib/broadcaststructs.h" @@ -50,12 +72,15 @@ #include "gromacs/mdtypes/enerdata.h" #include "gromacs/mdtypes/forceoutput.h" #include "gromacs/pbcutil/ishift.h" +#include "gromacs/pbcutil/pbc.h" #include "gromacs/utility/arrayref.h" #include "gromacs/utility/exceptions.h" #include "gromacs/utility/logger.h" #include "gromacs/utility/mpicomm.h" #include "gromacs/utility/stringutil.h" +#include "metatomic_timer.h" + #ifdef DIM # undef DIM #endif @@ -71,6 +96,7 @@ namespace gmx { +/*! \brief Normalizes the variant string for Metatomic output selection. */ static torch::optional normalize_variant(std::string variant_string) { if (variant_string == "no" || variant_string.empty()) @@ -83,17 +109,7 @@ static torch::optional normalize_variant(std::string variant_string } } - -static std::optional indexOf(ArrayRef vec, const int32_t val) -{ - auto it = std::find(vec.begin(), vec.end(), val); - if (it == vec.end()) - { - return std::nullopt; - } - return std::distance(vec.begin(), it); -} - +/*! \brief Converts GROMACS PbcType to a boolean tensor for Metatomic. */ static torch::Tensor preparePbcType(PbcType* pbcType, torch::Device device) { auto options = torch::TensorOptions().dtype(torch::kBool).device(device); @@ -108,80 +124,12 @@ static torch::Tensor preparePbcType(PbcType* pbcType, torch::Device device) } else if (*pbcType != PbcType::Xyz) { - GMX_THROW(InconsistentInputError("PBC type not supported.")); + GMX_THROW(InconsistentInputError("PBC type not supported by Metatomic interface.")); } return torch::tensor({ true, true, true }, options); } -static metatensor_torch::TensorBlock buildNeighborListFromPairlist(ArrayRef pairlist, - ArrayRef shiftVectors, - ArrayRef cellShifts, - ArrayRef positions, - torch::Device device, - torch::ScalarType dtype) -{ - const int64_t n_pairs = static_cast(pairlist.size() / 2); - - auto cpu_int_options = torch::TensorOptions().dtype(torch::kInt32).device(torch::kCPU); - auto cpu_float_options = torch::TensorOptions().dtype(torch::kFloat64).device(torch::kCPU); - - auto pair_samples_values = torch::zeros({ n_pairs, 5 }, cpu_int_options); - auto pair_samples_ptr = pair_samples_values.accessor(); - - // Full interatomic vectors (rj - ri + shift) - auto vectors_cpu = torch::zeros({ n_pairs, 3, 1 }, cpu_float_options); - auto vectors_accessor = vectors_cpu.accessor(); - - for (int64_t i = 0; i < n_pairs; i++) - { - int32_t atom_i = pairlist[2 * i]; - int32_t atom_j = pairlist[2 * i + 1]; - - // Access IVec elements (integers) - pair_samples_ptr[i][0] = static_cast(atom_i); - pair_samples_ptr[i][1] = static_cast(atom_j); - pair_samples_ptr[i][2] = cellShifts[i][0]; - pair_samples_ptr[i][3] = cellShifts[i][1]; - pair_samples_ptr[i][4] = cellShifts[i][2]; - - // Calculate r_ij = r_j - r_i + shift - double r_ij_x = - static_cast(positions[atom_j][0] - positions[atom_i][0] + shiftVectors[i][0]); - double r_ij_y = - static_cast(positions[atom_j][1] - positions[atom_i][1] + shiftVectors[i][1]); - double r_ij_z = - static_cast(positions[atom_j][2] - positions[atom_i][2] + shiftVectors[i][2]); - - vectors_accessor[i][0][0] = r_ij_x; - vectors_accessor[i][1][0] = r_ij_y; - vectors_accessor[i][2][0] = r_ij_z; - } - - auto final_samples_values = pair_samples_values.to(device); - auto final_vectors = vectors_cpu.to(dtype).to(device); - - auto neighbor_samples = torch::make_intrusive( - std::vector{ - "first_atom", "second_atom", "cell_shift_a", "cell_shift_b", "cell_shift_c" }, - final_samples_values); - - auto neighbor_component = torch::make_intrusive( - std::vector{ "xyz" }, - torch::tensor({ 0, 1, 2 }, torch::TensorOptions().dtype(torch::kInt32).device(device)) - .reshape({ 3, 1 })); - - auto neighbor_properties = torch::make_intrusive( - std::vector{ "distance" }, - torch::zeros({ 1, 1 }, torch::TensorOptions().dtype(torch::kInt32).device(device))); - - return torch::make_intrusive( - final_vectors, - neighbor_samples, - std::vector{ neighbor_component }, - neighbor_properties); -} - - +/*! \brief Internal data structure for Metatomic runtime states. */ struct MetatomicData { metatensor_torch::Module model = metatensor_torch::Module(torch::jit::Module()); @@ -191,6 +139,14 @@ struct MetatomicData torch::ScalarType dtype = torch::kFloat32; bool check_consistency = false; torch::Device device = torch::kCPU; + + //! Cached NL Labels that are identical every step (created once in constructor). + metatensor_torch::Labels cachedNLComponent; + metatensor_torch::Labels cachedNLProperties; + //! Cached sample column names (avoids heap-allocating string vector every step). + std::vector nlSampleNames = { + "first_atom", "second_atom", "cell_shift_a", "cell_shift_b", "cell_shift_c" + }; }; MetatomicForceProvider::MetatomicForceProvider(const MetatomicOptions& options, @@ -204,107 +160,143 @@ MetatomicForceProvider::MetatomicForceProvider(const MetatomicOptions& options, { GMX_LOG(logger_.info).asParagraph().appendText("Initializing MetatomicForceProvider..."); - // Pairlist-based neighbor lists don't work with domain decomposition yet (indices are local) - // Matches NNPot's limitation - if (mpiComm_.isParallel()) + if (const char* timerEnv = std::getenv("GMX_METATOMIC_TIMER")) { - GMX_THROW(NotImplementedError( - "Metatomic does not yet support domain decomposition. " - "Use thread-MPI (gmx mdrun) instead of MPI (mpirun gmx_mpi mdrun).")); + MetatomicTimer::enable(std::string(timerEnv) == "1"); } - // Only main rank loads model - if (mpiComm_.isMainRank()) + // With thread-MPI, each rank is a thread sharing the same process. + // PyTorch's internal OpenMP would spawn N threads per rank, causing + // massive oversubscription (e.g. 12 ranks × 12 OMP threads = 144 + // threads on 12 cores). Force single-threaded torch operations. + if (GMX_THREAD_MPI && mpiComm_.isParallel()) { - try - { - torch::optional extensions_directory = torch::nullopt; - if (!options_.params_.extensionsDirectory.empty()) - { - extensions_directory = options_.params_.extensionsDirectory; - } + at::set_num_threads(1); + } - data_->model = metatomic_torch::load_atomistic_model(options_.params_.modelPath_, - extensions_directory); - } - catch (const std::exception& e) + try + { + torch::optional extensions_directory = torch::nullopt; + if (!options_.params_.extensionsDirectory.empty()) { - GMX_THROW(APIError("Failed to load metatomic model: " + std::string(e.what()))); + extensions_directory = options_.params_.extensionsDirectory; } - data_->capabilities = - data_->model.run_method("capabilities").toCustomClass(); + data_->model = metatomic_torch::load_atomistic_model(options_.params_.modelPath_, + extensions_directory); + } + catch (const std::exception& e) + { + GMX_THROW(APIError("Failed to load metatomic model: " + std::string(e.what()))); + } - // Determine device using capabilities and optional environment variable - torch::optional desiredDevice = torch::nullopt; - if (const char* env = std::getenv("GMX_METATOMIC_DEVICE")) - { - desiredDevice = std::string(env); - } + data_->capabilities = + data_->model.run_method("capabilities").toCustomClass(); - const auto deviceType = - metatomic_torch::pick_device(data_->capabilities->supported_devices, desiredDevice); - data_->device = torch::Device(deviceType); + torch::optional desiredDevice = torch::nullopt; + if (const char* env = std::getenv("GMX_METATOMIC_DEVICE")) + { + desiredDevice = std::string(env); + } - GMX_LOG(logger_.info) - .asParagraph() - .appendTextFormatted("Metatomic using device: %s", data_->device.str().c_str()); + const auto deviceType = + metatomic_torch::pick_device(data_->capabilities->supported_devices, desiredDevice); + data_->device = torch::Device(deviceType); - auto requests_ivalue = data_->model.run_method("requested_neighbor_lists"); - for (const auto& request_ivalue : requests_ivalue.toList()) - { - data_->nl_requests.push_back( - request_ivalue.get().toCustomClass()); - } + // Cache NL Labels that are constant across steps (avoids per-step + // string vector + tensor allocation for component and properties). + auto devIntOpts = torch::TensorOptions().dtype(torch::kInt32).device(data_->device); + data_->cachedNLComponent = torch::make_intrusive( + std::vector{ "xyz" }, + torch::tensor({ 0, 1, 2 }, devIntOpts).reshape({ 3, 1 })); + data_->cachedNLProperties = torch::make_intrusive( + std::vector{ "distance" }, + torch::zeros({ 1, 1 }, devIntOpts)); - data_->model.to(data_->device); + GMX_LOG(logger_.info) + .asParagraph() + .appendTextFormatted("Metatomic using device: %s", data_->device.str().c_str()); - if (data_->capabilities->dtype() == "float64") - { - data_->dtype = torch::kFloat64; - } - else if (data_->capabilities->dtype() == "float32") + { + double interactionRange = data_->capabilities->engine_interaction_range("nm"); + std::string fname = "metatomic_debug_rank_" + std::to_string(mpiComm_.rank()) + ".log"; + FILE* fp = std::fopen(fname.c_str(), "w"); + if (fp) { - data_->dtype = torch::kFloat32; + std::fprintf(fp, + "=== Metatomic init (rank %d) ===\n" + "interaction_range(nm)=%.6f\n", + mpiComm_.rank(), + interactionRange); + std::fclose(fp); } - else + } + + auto requests_ivalue = data_->model.run_method("requested_neighbor_lists"); + for (const auto& request_ivalue : requests_ivalue.toList()) + { + auto nl_opt = request_ivalue.get().toCustomClass(); { - GMX_THROW(APIError("Unsupported dtype: " + data_->capabilities->dtype())); + std::string fname = "metatomic_debug_rank_" + std::to_string(mpiComm_.rank()) + ".log"; + FILE* fp = std::fopen(fname.c_str(), "a"); + if (fp) + { + std::fprintf(fp, + "NL request: cutoff()=%.6f, engine_cutoff(nm)=%.6f, full_list=%s\n", + nl_opt->cutoff(), + nl_opt->engine_cutoff("nm"), + nl_opt->full_list() ? "true" : "false"); + std::fclose(fp); + } } + data_->nl_requests.push_back(nl_opt); + } - data_->evaluations_options = - torch::make_intrusive(); - data_->evaluations_options->set_length_unit("nm"); + data_->model.to(data_->device); - auto outputs = data_->capabilities->outputs(); - auto v_energy = normalize_variant(options_.params_.variant); - auto energy_key = pick_output("energy", outputs, v_energy); + if (data_->capabilities->dtype() == "float64") + { + data_->dtype = torch::kFloat64; + } + else if (data_->capabilities->dtype() == "float32") + { + data_->dtype = torch::kFloat32; + } + else + { + GMX_THROW(APIError("Unsupported dtype from model capabilities: " + data_->capabilities->dtype())); + } - if (!outputs.contains(energy_key)) - { - GMX_THROW(APIError("the model at '" + options_.params_.modelPath_ - + "' does not provide " - "an '" - + energy_key + "' output, we can not use the metatomic interface.")); - } + data_->evaluations_options = torch::make_intrusive(); + data_->evaluations_options->set_length_unit("nm"); - auto requested_output = torch::make_intrusive(); - // TODO: take from the user - requested_output->per_atom = false; - requested_output->explicit_gradients = {}; - requested_output->set_unit("kJ/mol"); + auto outputs = data_->capabilities->outputs(); + auto v_energy = normalize_variant(options_.params_.variant); + auto energy_key = pick_output("energy", outputs, v_energy); - data_->evaluations_options->outputs.insert(energy_key, requested_output); - data_->check_consistency = options_.params_.checkConsistency; + if (!outputs.contains(energy_key)) + { + GMX_THROW( + APIError(formatString("The model at '%s' does not provide an '%s' output. " + "Metatomic interface cannot proceed.", + options_.params_.modelPath_.c_str(), + energy_key.c_str()))); } - const auto& mtaIndices = options_.params_.mtaIndices_; - const int32_t n_atoms = static_cast(mtaIndices.size()); + auto requested_output = torch::make_intrusive(); + // per_atom=true so the model returns per-atom energies (needed for + // correct energy decomposition when using the GROMACS pairlist in DD) + requested_output->per_atom = true; + requested_output->explicit_gradients = {}; + requested_output->set_unit("kJ/mol"); + + data_->evaluations_options->outputs.insert(energy_key, requested_output); + data_->check_consistency = options_.params_.checkConsistency; - positions_.resize(n_atoms); - atomNumbers_.resize(n_atoms, 0); - inputToLocalIndex_.resize(n_atoms, -1); - inputToGlobalIndex_.resize(n_atoms, -1); + // Allocate global force buffer sized to total MTA atoms + const auto& mtaIndices = options_.params_.mtaIndices_; + const int32_t n_total = static_cast(mtaIndices.size()); + globalForceBuffer_.resize(n_total, RVec({ 0.0, 0.0, 0.0 })); GMX_LOG(logger_.info) .asParagraph() @@ -313,30 +305,15 @@ MetatomicForceProvider::MetatomicForceProvider(const MetatomicOptions& options, MetatomicForceProvider::~MetatomicForceProvider() = default; -/*! \brief Updates the mapping between GROMACS local/global atom indices and the Metatomic model's input atoms. - * - * This function is subscribed to the `MDModulesAtomsRedistributedSignal`. It is called whenever - * atoms are redistributed across MPI ranks (Domain Decomposition) or reordered in memory (sorting). - * - * Its primary responsibilities are: - * 1. **Locate Input Atoms**: It iterates through the local atoms on the current rank to find - * which atoms correspond to the "input atoms" defined for the Metatomic model (via `mtaIndices`). - * 2. **Update Maps**: It populates `inputToLocalIndex_` (mapping model input index -> GROMACS local index) - * and `inputToGlobalIndex_` (mapping model input index -> GROMACS global tag). - * 3. **Gather Atomic Numbers**: It ensures `atomNumbers_` (Z numbers) are correctly associated with - * the current local atoms, performing an MPI reduction if necessary to gather data from - * distributed ranks. - * - * \param[in] signal Contains the new mapping of global atom indices to local buffer indices after redistribution. - */ void MetatomicForceProvider::gatherAtomNumbersIndices(const MDModulesAtomsRedistributedSignal& signal) { - const auto& mtaIndices = options_.params_.mtaIndices_; - const int32_t numInput = static_cast(mtaIndices.size()); + const auto& mtaIndices = options_.params_.mtaIndices_; + const int32_t numTotalMta = static_cast(mtaIndices.size()); - inputToLocalIndex_.assign(numInput, -1); - inputToGlobalIndex_.assign(numInput, -1); - atomNumbers_.assign(numInput, 0); + mtaToGmxLocal_.clear(); + mtaToGlobalMta_.clear(); + atomNumbers_.clear(); + gmxLocalToMtaIdx_.clear(); if (mpiComm_.isParallel()) { @@ -344,36 +321,147 @@ void MetatomicForceProvider::gatherAtomNumbersIndices(const MDModulesAtomsRedist "Global atom indices required for domain decomposition."); auto globalAtomIndices = signal.globalAtomIndices_.value(); const int32_t numLocal = signal.x_.size(); + const int32_t numLocalPlusHalo = globalAtomIndices.size(); + + // Build a map from global atom index to MTA index for fast lookup + std::unordered_map globalToMtaIdx; + for (int32_t j = 0; j < numTotalMta; j++) + { + globalToMtaIdx[static_cast(mtaIndices[j])] = j; + } - for (int32_t i = 0; i < static_cast(globalAtomIndices.size()); i++) + // Separate home and halo MTA atoms, deduplicating periodic ghosts. + // Each unique MTA atom gets one model index. Periodic ghost images + // are NOT added as separate model atoms, but their GROMACS local + // buffer indices ARE recorded in gmxLocalToMtaIdx_ so that + // setPairlist can resolve pairlist entries referencing any image. + std::vector homeGmxLocal; + std::vector homeGlobalMta; + std::vector haloGmxLocal; + std::vector haloGlobalMta; + + // First pass: assign model indices to unique MTA atoms + std::unordered_map mtaIdxToModelIdx; + int32_t numDuplicatesSkipped = 0; + + for (int32_t i = 0; i < numLocalPlusHalo; i++) { int32_t globalIdx = globalAtomIndices[i]; - for (int32_t j = 0; j < numInput; j++) + auto it = globalToMtaIdx.find(globalIdx); + if (it != globalToMtaIdx.end()) { - if (options_.params_.mtaAtoms_->globalIndex()[j] == globalIdx) + int32_t mtaIdx = it->second; + + if (mtaIdxToModelIdx.count(mtaIdx)) + { + // Periodic ghost: record mapping but don't create new model atom + numDuplicatesSkipped++; + } + else { if (i < numLocal) { - inputToLocalIndex_[j] = i; - inputToGlobalIndex_[j] = globalIdx; - atomNumbers_[j] = options_.params_.atoms_.atom[globalIdx].atomnumber; + // Will be assigned model index = homeGmxLocal.size() (filled later) + homeGmxLocal.push_back(i); + homeGlobalMta.push_back(mtaIdx); } - break; + else + { + haloGmxLocal.push_back(i); + haloGlobalMta.push_back(mtaIdx); + } + // Placeholder: model index will be set after we know numHomeMta_ + mtaIdxToModelIdx[mtaIdx] = -1; } } } - mpiComm_.sumReduce(numInput, atomNumbers_.data()); + + // Assign final model indices: home [0, numHome), halo [numHome, numLocal) + int32_t modelIdx = 0; + for (int32_t k = 0; k < static_cast(homeGmxLocal.size()); k++) + { + mtaIdxToModelIdx[homeGlobalMta[k]] = modelIdx++; + } + for (int32_t k = 0; k < static_cast(haloGmxLocal.size()); k++) + { + mtaIdxToModelIdx[haloGlobalMta[k]] = modelIdx++; + } + + // Second pass: build complete gmxLocal → modelIdx mapping for ALL images + for (int32_t i = 0; i < numLocalPlusHalo; i++) + { + int32_t globalIdx = globalAtomIndices[i]; + auto it = globalToMtaIdx.find(globalIdx); + if (it != globalToMtaIdx.end()) + { + gmxLocalToMtaIdx_[i] = mtaIdxToModelIdx[it->second]; + } + } + + { + std::string fname = "metatomic_debug_rank_" + std::to_string(mpiComm_.rank()) + ".log"; + FILE* fp = std::fopen(fname.c_str(), "a"); + if (fp) + { + std::fprintf(fp, + "gatherAtoms: home=%zu halo=%zu duplicatesSkipped=%d gmxLocalEntries=%zu\n", + homeGmxLocal.size(), + haloGmxLocal.size(), + numDuplicatesSkipped, + gmxLocalToMtaIdx_.size()); + std::fclose(fp); + } + } + + numHomeMta_ = static_cast(homeGmxLocal.size()); + numLocalMta_ = numHomeMta_ + static_cast(haloGmxLocal.size()); + + // Assign local model indices: home -> [0, numHomeMta_), halo -> [numHomeMta_, numLocalMta_) + mtaToGmxLocal_.resize(numLocalMta_); + mtaToGlobalMta_.resize(numLocalMta_); + atomNumbers_.resize(numLocalMta_); + + for (int32_t k = 0; k < numHomeMta_; k++) + { + int32_t gmxLocal = homeGmxLocal[k]; + int32_t globalIdx = globalAtomIndices[gmxLocal]; + + mtaToGmxLocal_[k] = gmxLocal; + mtaToGlobalMta_[k] = homeGlobalMta[k]; + atomNumbers_[k] = options_.params_.atoms_.atom[globalIdx].atomnumber; + } + + for (int32_t k = 0; k < static_cast(haloGmxLocal.size()); k++) + { + int32_t modelIdx = numHomeMta_ + k; + int32_t gmxLocal = haloGmxLocal[k]; + int32_t globalIdx = globalAtomIndices[gmxLocal]; + + mtaToGmxLocal_[modelIdx] = gmxLocal; + mtaToGlobalMta_[modelIdx] = haloGlobalMta[k]; + atomNumbers_[modelIdx] = options_.params_.atoms_.atom[globalIdx].atomnumber; + } } else { + // Serial / thread-MPI: all MTA atoms are home, no halos const auto* mtaAtoms = options_.params_.mtaAtoms_.get(); - for (int32_t i = 0; i < numInput; i++) + numHomeMta_ = numTotalMta; + numLocalMta_ = numTotalMta; + + mtaToGmxLocal_.resize(numTotalMta); + mtaToGlobalMta_.resize(numTotalMta); + atomNumbers_.resize(numTotalMta); + + for (int32_t i = 0; i < numTotalMta; i++) { - int32_t localIndex = mtaAtoms->localIndex()[i]; - int32_t globalIdx = mtaAtoms->globalIndex()[mtaAtoms->collectiveIndex()[i]]; - inputToLocalIndex_[i] = localIndex; - inputToGlobalIndex_[i] = globalIdx; - atomNumbers_[i] = options_.params_.atoms_.atom[globalIdx].atomnumber; + int32_t localIndex = mtaAtoms->localIndex()[i]; + int32_t globalIdx = mtaAtoms->globalIndex()[mtaAtoms->collectiveIndex()[i]]; + + mtaToGmxLocal_[i] = localIndex; + mtaToGlobalMta_[i] = i; + atomNumbers_[i] = options_.params_.atoms_.atom[globalIdx].atomnumber; + gmxLocalToMtaIdx_[localIndex] = i; } } @@ -381,102 +469,65 @@ void MetatomicForceProvider::gatherAtomNumbersIndices(const MDModulesAtomsRedist "Some atom numbers not set."); } -/*! \brief Updates the internal position buffer with the coordinates of the Metatomic atoms. - * - * This function extracts the coordinates of the atoms relevant to the Metatomic model - * from the full GROMACS local atom array (`pos`). - * - * 1. **Filtering:** `pos` contains all local atoms (solute, solvent, ions). - * This function copies only the atoms defined in `mtaIndices` to `positions_`. - * 2. **Ordering/Packing:** GROMACS reorders atoms dynamically for Domain Decomposition. - * This function uses `inputToLocalIndex_` to collect atoms and pack them into a contiguous, ordered buffer - * - * \param[in] pos The array of all local atom coordinates for the current step. - */ void MetatomicForceProvider::gatherAtomPositions(ArrayRef pos) { - const size_t numInput = inputToLocalIndex_.size(); - positions_.assign(numInput, RVec({ 0.0, 0.0, 0.0 })); - - for (size_t i = 0; i < numInput; i++) - { - if (inputToLocalIndex_[i] != -1) - { - positions_[i] = pos[inputToLocalIndex_[i]]; - } - } - - if (mpiComm_.isParallel()) + positions_.resize(numLocalMta_); + for (int32_t i = 0; i < numLocalMta_; i++) { - mpiComm_.sumReduce(3 * numInput, positions_.data()->as_vec()); + positions_[i] = pos[mtaToGmxLocal_[i]]; } } void MetatomicForceProvider::setPairlist(const MDModulesPairlistConstructedSignal& signal) { - fullPairlist_.assign(signal.excludedPairlist_.begin(), signal.excludedPairlist_.end()); - doPairlist_ = true; -} - -void MetatomicForceProvider::preparePairlistInput() -{ - if (!doPairlist_) + pairlistMta_.clear(); + cellShiftsMta_.clear(); + + // Use gmxLocalToMtaIdx_ which maps ALL GROMACS local buffer indices + // (including periodic ghost images) to their MTA model index. + // + // Sign convention: GROMACS shifts atom I (first): d = x[I]+shift - x[J]. + // Metatensor shifts atom J (second): r_ij = x[J]+cell·box - x[I]. + // So metatensor cell shift = -GROMACS cell shift. + for (const auto& entry : signal.excludedPairlist_) { - return; - } - - GMX_ASSERT(!fullPairlist_.empty(), "Pairlist empty!"); - - const int32_t numPairs = gmx::ssize(fullPairlist_); - pairlistForModel_.clear(); - pairlistForModel_.reserve(2 * numPairs); - shiftVectors_.clear(); - shiftVectors_.reserve(numPairs); - cellShifts_.clear(); - cellShifts_.reserve(numPairs); - - for (int32_t i = 0; i < numPairs; i++) - { - const auto [atomPair, shiftIndex] = fullPairlist_[i]; - - auto inputIdxA = indexOf(inputToGlobalIndex_, atomPair.first); - auto inputIdxB = indexOf(inputToGlobalIndex_, atomPair.second); - - if (inputIdxA.has_value() && inputIdxB.has_value()) + const auto& [atomPair, shiftIndex] = entry; + auto itA = gmxLocalToMtaIdx_.find(atomPair.first); + auto itB = gmxLocalToMtaIdx_.find(atomPair.second); + if (itA != gmxLocalToMtaIdx_.end() && itB != gmxLocalToMtaIdx_.end()) { - RVec shift; - const IVec unitShift = shiftIndexToXYZ(shiftIndex); - mvmul_ur0(box_, unitShift.toRVec(), shift); - - pairlistForModel_.push_back(static_cast(inputIdxA.value())); - pairlistForModel_.push_back(static_cast(inputIdxB.value())); - shiftVectors_.push_back(shift); - cellShifts_.push_back(unitShift); + pairlistMta_.push_back(itA->second); + pairlistMta_.push_back(itB->second); + const IVec gmxShift = shiftIndexToXYZ(shiftIndex); + cellShiftsMta_.push_back(IVec(-gmxShift[XX], -gmxShift[YY], -gmxShift[ZZ])); } } - - GMX_RELEASE_ASSERT(pairlistForModel_.size() == shiftVectors_.size() * 2, - "Pairlist/shift size mismatch."); - doPairlist_ = false; } + void MetatomicForceProvider::calculateForces(const ForceProviderInput& inputs, ForceProviderOutput* outputs) { - const int32_t n_atoms = static_cast(options_.params_.mtaIndices_.size()); + MetatomicTimer totalTimer("calculateForces", mpiComm_); - gatherAtomPositions(inputs.x_); - copy_mat(inputs.box_, box_); - preparePairlistInput(); + const int32_t numTotalMta = static_cast(options_.params_.mtaIndices_.size()); - // Force tensor - main rank fills, others have zeros - torch::Tensor forceTensor = torch::zeros( - { n_atoms, 3 }, torch::TensorOptions().dtype(torch::kFloat64).device(data_->device)); + // Fill local positions (no MPI communication) + { + MetatomicTimer timer("gatherAtomPositions", mpiComm_); + gatherAtomPositions(inputs.x_); + } + copy_mat(inputs.box_, box_); - // Virial tensor for pressure/stress calculations - torch::Tensor virialTensor = torch::zeros({ 3, 3 }, torch::TensorOptions().dtype(torch::kFloat64)); + // Model inference + torch::Tensor forceTensor; + torch::Tensor virialTensor; + double energy = 0.0; - if (mpiComm_.isMainRank()) { + MetatomicTimer modelTimer("model inference", mpiComm_); + + MetatomicTimer tensorPrepTimer("tensorPrep", mpiComm_); + auto gromacs_scalar_type = torch::kFloat32; if (std::is_same_v) { @@ -484,7 +535,7 @@ void MetatomicForceProvider::calculateForces(const ForceProviderInput& inputs, F } auto cpu_blob_options = torch::TensorOptions().dtype(gromacs_scalar_type).device(torch::kCPU); - auto torch_positions = torch::from_blob(positions_.data()->as_vec(), { n_atoms, 3 }, cpu_blob_options) + auto torch_positions = torch::from_blob(positions_.data()->as_vec(), { static_cast(numLocalMta_), 3 }, cpu_blob_options) .to(data_->dtype) .to(data_->device) .set_requires_grad(true); @@ -492,13 +543,10 @@ void MetatomicForceProvider::calculateForces(const ForceProviderInput& inputs, F auto torch_cell = torch::from_blob(&box_, { 3, 3 }, cpu_blob_options).to(data_->dtype).to(data_->device); - // Create strain tensor for virial computation (like LAMMPS does) auto strain = torch::eye( 3, torch::TensorOptions().dtype(data_->dtype).device(data_->device).requires_grad(true)); - // Apply strain to cell: strained_cell = cell @ strain - auto strained_cell = torch::matmul(torch_cell, strain); - // Apply strain to positions: r' = r @ strain + auto strained_cell = torch::matmul(torch_cell, strain); auto strained_positions = torch::matmul(torch_positions, strain); auto torch_pbc = preparePbcType(options_.params_.pbcType_.get(), data_->device); @@ -508,15 +556,101 @@ void MetatomicForceProvider::calculateForces(const ForceProviderInput& inputs, F auto system = torch::make_intrusive( torch_types, strained_positions, strained_cell, torch_pbc); - // Build neighbor list from GROMACS pairlist + tensorPrepTimer.stop(); + + // Build NL directly into raw buffers, then wrap with from_blob. + // Shift vectors are computed inline (no separate prepareNL pass). + // Component and properties Labels are cached (identical every step). + MetatomicTimer buildNLTimer("buildNL", mpiComm_); + for (const auto& request : data_->nl_requests) { - auto neighbors = buildNeighborListFromPairlist( - pairlistForModel_, shiftVectors_, cellShifts_, positions_, data_->device, data_->dtype); + const int64_t nHalf = static_cast(pairlistMta_.size() / 2); + const bool full = request->full_list(); + const int64_t nPairs = full ? 2 * nHalf : nHalf; + + nlSamplesBuffer_.resize(nPairs * 5); + nlVectorsBuffer_.resize(nPairs * 3); + + for (int64_t k = 0; k < nHalf; k++) + { + const int32_t ai = pairlistMta_[2 * k]; + const int32_t aj = pairlistMta_[2 * k + 1]; + + // Compute shift vector from cell shift and current box + RVec shift; + mvmul_ur0(inputs.box_, cellShiftsMta_[k].toRVec(), shift); + + // Displacement: r_ij = pos[j] - pos[i] + shift (metatensor convention) + const double dx = static_cast(positions_[aj][0] - positions_[ai][0] + shift[0]); + const double dy = static_cast(positions_[aj][1] - positions_[ai][1] + shift[1]); + const double dz = static_cast(positions_[aj][2] - positions_[ai][2] + shift[2]); + + const int64_t fwd = full ? 2 * k : k; + nlSamplesBuffer_[5 * fwd + 0] = ai; + nlSamplesBuffer_[5 * fwd + 1] = aj; + nlSamplesBuffer_[5 * fwd + 2] = cellShiftsMta_[k][0]; + nlSamplesBuffer_[5 * fwd + 3] = cellShiftsMta_[k][1]; + nlSamplesBuffer_[5 * fwd + 4] = cellShiftsMta_[k][2]; + nlVectorsBuffer_[3 * fwd + 0] = dx; + nlVectorsBuffer_[3 * fwd + 1] = dy; + nlVectorsBuffer_[3 * fwd + 2] = dz; + + if (full) + { + // Reverse pair (j,i) with negated shifts and displacement + const int64_t rev = 2 * k + 1; + nlSamplesBuffer_[5 * rev + 0] = aj; + nlSamplesBuffer_[5 * rev + 1] = ai; + nlSamplesBuffer_[5 * rev + 2] = -cellShiftsMta_[k][0]; + nlSamplesBuffer_[5 * rev + 3] = -cellShiftsMta_[k][1]; + nlSamplesBuffer_[5 * rev + 4] = -cellShiftsMta_[k][2]; + nlVectorsBuffer_[3 * rev + 0] = -dx; + nlVectorsBuffer_[3 * rev + 1] = -dy; + nlVectorsBuffer_[3 * rev + 2] = -dz; + } + } + + MetatomicTimer fromBlobTimer("fromBlob", mpiComm_); + auto samples_tensor = torch::from_blob( + nlSamplesBuffer_.data(), { nPairs, 5 }, + torch::TensorOptions().dtype(torch::kInt32)).to(data_->device); + auto vectors_tensor = torch::from_blob( + nlVectorsBuffer_.data(), { nPairs, 3, 1 }, + torch::TensorOptions().dtype(torch::kFloat64)).to(data_->dtype).to(data_->device); + fromBlobTimer.stop(); + + MetatomicTimer labelsTimer("makeSampleLabels", mpiComm_); + auto neighbor_samples = torch::make_intrusive( + data_->nlSampleNames, samples_tensor); + labelsTimer.stop(); + + MetatomicTimer blockTimer("makeTensorBlock", mpiComm_); + auto neighbors = torch::make_intrusive( + vectors_tensor, + neighbor_samples, + std::vector{ data_->cachedNLComponent }, + data_->cachedNLProperties); + blockTimer.stop(); + + MetatomicTimer autogradTimer("registerAutograd", mpiComm_); metatomic_torch::register_autograd_neighbors(system, neighbors, data_->check_consistency); + autogradTimer.stop(); + + MetatomicTimer addNLTimer("addNeighborList", mpiComm_); system->add_neighbor_list(request, neighbors); + addNLTimer.stop(); } + buildNLTimer.stop(); + + // No selected_atoms: each pair is on exactly one rank, so summing + // all per-atom energies (home + halo) gives the correct pair energy. + // GROMACS sums across ranks via global_stat. + data_->evaluations_options->set_selected_atoms(torch::nullopt); + + MetatomicTimer forwardTimer("forward", mpiComm_); + metatensor_torch::TensorMap output_map; try { @@ -533,51 +667,110 @@ void MetatomicForceProvider::calculateForces(const ForceProviderInput& inputs, F GMX_THROW(APIError("[Metatomic] Model evaluation failed: " + std::string(e.what()))); } + forwardTimer.stop(); + auto energy_block = metatensor_torch::TensorMapHolder::block_by_id(output_map, 0); auto energy_tensor = energy_block->values(); - outputs->enerd_.term[InteractionFunction::MetatomicPotentialEnergy] = - static_cast(energy_tensor.sum().item()); + energy = energy_tensor.sum().item(); + + // Diagnostic: log pairlist size, per-rank energy and MPI sum + { + double mpiSumEnergy = energy; + if (mpiComm_.isParallel()) + { + mpiComm_.sumReduce(1, &mpiSumEnergy); + } + std::string fname = + "metatomic_debug_rank_" + std::to_string(mpiComm_.rank()) + ".log"; + FILE* fp = std::fopen(fname.c_str(), "a"); + if (fp) + { + std::fprintf(fp, + "pairlistPairs=%zu, numLocalMta=%d, numHomeMta=%d, " + "energy: perRank=%.6f, mpiSum=%.6f\n", + pairlistMta_.size() / 2, + numLocalMta_, + numHomeMta_, + energy, + mpiSumEnergy); + std::fclose(fp); + } + } + + MetatomicTimer backwardTimer("backward", mpiComm_); - // Reset gradients before backward torch_positions.mutable_grad() = torch::Tensor(); strain.mutable_grad() = torch::Tensor(); - // Compute forces and virial via backward propagation energy_tensor.backward(-torch::ones_like(energy_tensor)); - auto grad = torch_positions.grad(); - forceTensor = grad.to(torch::kCPU).to(torch::kFloat64); + backwardTimer.stop(); + + MetatomicTimer toCPUTimer("toCPU", mpiComm_); - // Get virial from strain gradient + forceTensor = torch_positions.grad().to(torch::kCPU).to(torch::kFloat64); virialTensor = strain.grad().to(torch::kCPU).to(torch::kFloat64); + + toCPUTimer.stop(); } - // Distribute forces (sumReduce acts as broadcast since non-main ranks have zeros) + // Force distribution via all-reduce. + // backward() produces forces on ALL local atoms (home + halo). Since + // ForceWithVirial forces are NOT communicated by dd_move_f (which only + // handles ForceWithShiftForces), we must all-reduce ourselves. Each rank + // scatters its local forces into a global buffer indexed by global MTA + // index. After all-reduce, each rank reads back only its home atoms. + MetatomicTimer forceScatterTimer("forceScatter", mpiComm_); + + auto forceAccessor = forceTensor.accessor(); + if (mpiComm_.isParallel()) { - mpiComm_.sumReduce(n_atoms * 3, static_cast(forceTensor.data_ptr())); - mpiComm_.sumReduce(9, static_cast(virialTensor.data_ptr())); - } + globalForceBuffer_.assign(numTotalMta, RVec({ 0.0, 0.0, 0.0 })); - // Apply forces to local atoms only - auto forceAccessor = forceTensor.accessor(); - for (int32_t i = 0; i < n_atoms; ++i) + for (int32_t i = 0; i < numLocalMta_; i++) + { + int32_t globalMtaIdx = mtaToGlobalMta_[i]; + globalForceBuffer_[globalMtaIdx][0] = static_cast(forceAccessor[i][0]); + globalForceBuffer_[globalMtaIdx][1] = static_cast(forceAccessor[i][1]); + globalForceBuffer_[globalMtaIdx][2] = static_cast(forceAccessor[i][2]); + } + + mpiComm_.sumReduce(3 * numTotalMta, globalForceBuffer_.data()->as_vec()); + + // Apply forces only to home MTA atoms from the reduced buffer + for (int32_t i = 0; i < numHomeMta_; i++) + { + int32_t gmxIdx = mtaToGmxLocal_[i]; + int32_t globalMtaIdx = mtaToGlobalMta_[i]; + outputs->forceWithVirial_.force_[gmxIdx][0] += globalForceBuffer_[globalMtaIdx][0]; + outputs->forceWithVirial_.force_[gmxIdx][1] += globalForceBuffer_[globalMtaIdx][1]; + outputs->forceWithVirial_.force_[gmxIdx][2] += globalForceBuffer_[globalMtaIdx][2]; + } + } + else { - if (inputToLocalIndex_[i] != -1) + // Serial: apply forces directly + for (int32_t i = 0; i < numLocalMta_; i++) { - outputs->forceWithVirial_.force_[inputToLocalIndex_[i]][0] += forceAccessor[i][0]; - outputs->forceWithVirial_.force_[inputToLocalIndex_[i]][1] += forceAccessor[i][1]; - outputs->forceWithVirial_.force_[inputToLocalIndex_[i]][2] += forceAccessor[i][2]; + int32_t gmxIdx = mtaToGmxLocal_[i]; + outputs->forceWithVirial_.force_[gmxIdx][0] += static_cast(forceAccessor[i][0]); + outputs->forceWithVirial_.force_[gmxIdx][1] += static_cast(forceAccessor[i][1]); + outputs->forceWithVirial_.force_[gmxIdx][2] += static_cast(forceAccessor[i][2]); } } - // Apply virial contribution - // GROMACS uses a 3x3 virial tensor in forceWithVirial_ - // Copy the tensor data into a GROMACS matrix and use the public API + forceScatterTimer.stop(); + + // Energy: each rank's energy is the sum of per-atom energies for all local + // atoms (home + halo) from the pairs assigned to this rank. GROMACS + // global_stat sums across ranks to get the system total. + outputs->enerd_.term[InteractionFunction::MetatomicPotentialEnergy] = static_cast(energy); + + // Virial: same decomposition as energy — per-rank portion, summed by GROMACS. matrix virialMatrix; auto virialAccessor = virialTensor.accessor(); - // TODO: technically this is DIM, not 3... for (int32_t i = 0; i < 3; ++i) { for (int32_t j = 0; j < 3; ++j) diff --git a/src/gromacs/applied_forces/metatomic/metatomic_forceprovider.h b/src/gromacs/applied_forces/metatomic/metatomic_forceprovider.h index c26ce7b187..0448c0fed1 100644 --- a/src/gromacs/applied_forces/metatomic/metatomic_forceprovider.h +++ b/src/gromacs/applied_forces/metatomic/metatomic_forceprovider.h @@ -41,6 +41,8 @@ #pragma once +#include + #include "gromacs/mdtypes/iforceprovider.h" #include "metatomic_options.h" @@ -56,15 +58,43 @@ struct MDModulesPairlistConstructedSignal; class MDLogger; class MpiComm; -/*! For compatibility with pairlist data structure in MDModulesPairlistConstructedSignal. - * Contains pairs like ((atom1, atom2), shiftIndex). - */ -using PairlistEntry = std::pair, int32_t>; - /*! \brief \internal * MetatomicForceProvider class * * Implements the IForceProvider interface for the Metatomic force provider. + * + * ## Domain decomposition strategy + * + * Each rank evaluates the model on its local (home + halo) MTA atoms. + * The neighbor list comes from the GROMACS plain pairlist (excludedPairlist), + * which assigns each pair to exactly one rank — no double counting. + * + * **Atoms**: In DD, GROMACS partitions atoms into "home" atoms (owned by this + * rank) and "halo" atoms (copies from neighboring ranks needed for short-range + * interactions). The same global atom may appear as multiple periodic ghost + * images in the halo. We deduplicate these so each atom has one model index, + * but record ALL GROMACS local buffer indices in gmxLocalToMtaIdx_ so that + * the pairlist (which may reference any image) can be resolved. + * + * **Pairs**: MTA-MTA pairs are excluded from classical nonbonded interactions + * via intermolecularExclusionGroup (set by addEmbeddedNBExclusions). The + * GROMACS pairlist builder reports these excluded pairs in excludedPairlist_, + * filtered to the plainPairlistRange (= model cutoff). Each pair appears on + * exactly one rank. + * + * **Energy**: With per_atom=true, the model decomposes energy per atom. + * selected_atoms is always nullopt: we sum ALL per-atom energies (home + halo) + * on each rank. Since each pair is on one rank, the per-pair energy + * (V_ij/2 on atom i + V_ij/2 on atom j) sums to V_ij on that rank. + * GROMACS global_stat sums across ranks for the total. + * + * **Forces**: backward() produces forces on all local atoms (home + halo). + * Since ForceWithVirial is not communicated by dd_move_f, we scatter forces + * into a global buffer and MPI all-reduce, then apply only to home atoms. + * + * **Shift convention**: GROMACS shifts atom I (first): d = x[I]+shift - x[J]. + * Metatensor convention: r_ij = x[J] + cell_shift*box - x[I]. Therefore + * metatensor cell shifts = negated GROMACS cell shifts. */ class MetatomicForceProvider final : public IForceProvider { @@ -83,52 +113,62 @@ class MetatomicForceProvider final : public IForceProvider //! Gather atom numbers and indices. Triggered on AtomsRedistributed signal. void gatherAtomNumbersIndices(const MDModulesAtomsRedistributedSignal& signal); - //! Set pairlist from notification and filter to MTA atom pairs. + //! Store GROMACS pairlist and convert to MTA model indices. void setPairlist(const MDModulesPairlistConstructedSignal& signal); private: - //! Gather atom positions for MTA input. - void gatherAtomPositions(ArrayRef globalPositions); - - //! Prepare pairlist input for model - void preparePairlistInput(); + //! Gather atom positions for MTA input (local only, no MPI). + void gatherAtomPositions(ArrayRef positions); const MetatomicOptions& options_; const MDLogger& logger_; const MpiComm& mpiComm_; - //! vector storing all MTA atom positions + //! Positions of local MTA atoms, indexed by model index [0, numLocalMta_). + //! Home atoms occupy [0, numHomeMta_), halo atoms [numHomeMta_, numLocalMta_). std::vector positions_; - //! vector storing all atomic numbers + //! Atomic numbers of local MTA atoms, same indexing as positions_. std::vector atomNumbers_; - //! lookup table to map model input indices [0...numInput) to local atom indices - std::vector inputToLocalIndex_; - - //! lookup table to map model input indices to global atom indices - std::vector inputToGlobalIndex_; - - //! Full pairlist from MDModules notification - std::vector fullPairlist_; - - //! Interacting pairs of MTA atoms within cutoff, for model input - std::vector pairlistForModel_; - - //! Shift vectors for each atom pair in pairlistForModel_ - std::vector shiftVectors_; - - //! Cell shifts - std::vector cellShifts_; + //! Number of home (owned by this rank) MTA atoms. + int32_t numHomeMta_ = 0; + //! Number of unique local MTA atoms (home + halo, after deduplication). + int32_t numLocalMta_ = 0; + + //! Maps model index [0, numLocalMta_) -> GROMACS local buffer index (first occurrence). + //! Used for position gathering and force scattering. + std::vector mtaToGmxLocal_; + //! Maps model index [0, numLocalMta_) -> global MTA index [0, N_total_mta). + //! Used for scatter/gather in the global force buffer during MPI all-reduce. + std::vector mtaToGlobalMta_; + //! Maps ANY GROMACS local buffer index -> MTA model index. + //! Includes ALL periodic ghost images of each atom (not just the first). + //! Needed because excludedPairlist_ entries can reference any image. + std::unordered_map gmxLocalToMtaIdx_; + + //! Global force buffer [N_total_mta] for MPI all-reduce of forces. + //! Each rank scatters its local forces here, all-reduce sums them, + //! then home forces are read back. + std::vector globalForceBuffer_; + + //! Pairlist in MTA model indices, flat [i0,j0, i1,j1, ...]. + //! Built from GROMACS excludedPairlist_ with negated cell shifts. + std::vector pairlistMta_; + //! Cell shifts for each pair (metatensor convention: shift applied to second atom). + std::vector cellShiftsMta_; + + //! Pre-allocated raw buffers for NL construction. + //! Avoids per-step torch::zeros allocations and accessor overhead. + //! Filled directly, then wrapped with torch::from_blob (zero-cost). + std::vector nlSamplesBuffer_; //!< flat [n_pairs * 5]: i, j, cs_a, cs_b, cs_c + std::vector nlVectorsBuffer_; //!< flat [n_pairs * 3]: dx, dy, dz //! local copy of simulation box matrix box_; //! Data required for metatomic calculations std::unique_ptr data_; - - //! flag to check if pairlist data should be prepared - bool doPairlist_ = false; }; } // namespace gmx diff --git a/src/gromacs/applied_forces/metatomic/metatomic_forceprovider_stub.cpp b/src/gromacs/applied_forces/metatomic/metatomic_forceprovider_stub.cpp index 302dc2559d..22fa70e6cc 100644 --- a/src/gromacs/applied_forces/metatomic/metatomic_forceprovider_stub.cpp +++ b/src/gromacs/applied_forces/metatomic/metatomic_forceprovider_stub.cpp @@ -80,11 +80,17 @@ void MetatomicForceProvider::calculateForces(const ForceProviderInput& /*inputs* { } -void MetatomicForceProvider::updateLocalAtoms() {} -void MetatomicForceProvider::gatherAtomPositions(ArrayRef globalPositions) { - (void)globalPositions; +void MetatomicForceProvider::gatherAtomNumbersIndices(const MDModulesAtomsRedistributedSignal& /*signal*/) +{ +} + +void MetatomicForceProvider::setPairlist(const MDModulesPairlistConstructedSignal& /*signal*/) +{ +} + +void MetatomicForceProvider::gatherAtomPositions(ArrayRef /*positions*/) +{ } -void MetatomicForceProvider::gatherAtomNumbersIndices() {} CLANG_DIAGNOSTIC_RESET diff --git a/src/gromacs/applied_forces/metatomic/metatomic_mdmodule.cpp b/src/gromacs/applied_forces/metatomic/metatomic_mdmodule.cpp index 6f3270f0ed..41c3b6dde2 100644 --- a/src/gromacs/applied_forces/metatomic/metatomic_mdmodule.cpp +++ b/src/gromacs/applied_forces/metatomic/metatomic_mdmodule.cpp @@ -239,7 +239,16 @@ class MetatomicMDModule final : public IMDModule GMX_THROW(InconsistentInputError("Metatomic model cutoff is 0.0 or invalid.")); } } - // Register the requirement with GROMACS + // TODO: For multi-layer GNN models (MACE, NequIP, etc.) the + // interaction_range should be n_layers * cutoff so that DD halos + // are deep enough for message-passing. Many models currently + // report interaction_range == cutoff, which makes DD give wrong + // energies because halo atoms lack complete neighborhoods. + // Unlike LAMMPS (which adds a ~2 Å neighbor skin on top of the + // cutoff), GROMACS caps the DD range at rlist, so we cannot add + // extra range here. The model must report the correct + // interaction_range, or the user must increase rcoulomb/rvdw in + // the .mdp so that rlist >= interaction_range. ranges->addRange(max_cutoff); }; // Register the callback diff --git a/src/gromacs/applied_forces/metatomic/metatomic_timer.h b/src/gromacs/applied_forces/metatomic/metatomic_timer.h new file mode 100644 index 0000000000..e4f056a53f --- /dev/null +++ b/src/gromacs/applied_forces/metatomic/metatomic_timer.h @@ -0,0 +1,144 @@ +/* + * This file is part of the GROMACS molecular simulation package. + * + * Copyright 2024- The GROMACS Authors + * and the project initiators Erik Lindahl, Berk Hess and David van der Spoel. + * Consult the AUTHORS/COPYING files and https://www.gromacs.org for details. + * + * GROMACS is free software; you can redistribute it and/or + * modify it under the terms of the GNU Lesser General Public License + * as published by the Free Software Foundation; either version 2.1 + * of the License, or (at your option) any later version. + * + * GROMACS is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU + * Lesser General Public License for more details. + * + * You should have received a copy of the GNU Lesser General Public + * License along with GROMACS; if not, see + * https://www.gnu.org/licenses, or write to the Free Software Foundation, + * Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA. + * + * If you want to redistribute modifications to GROMACS, please + * consider that scientific software is very special. Version + * control is crucial - bugs must be traceable. We will be happy to + * consider code for inclusion in the official distribution, but + * derived work must not be called official GROMACS. Details are found + * in the README & COPYING files - if they are missing, get the + * official version at https://www.gromacs.org. + * + * To help us fund GROMACS development, we humbly ask that you cite + * the research papers on the package. Check out https://www.gromacs.org. + */ +/*! \internal \file + * \brief + * Scoped timer for Metatomic force provider profiling. + * + * RAII timer that writes nested timing information to per-rank files + * (metatomic_timer_rank_N.log). Enable with GMX_METATOMIC_TIMER=1. + * + * \author Metatensor developers + * \ingroup module_applied_forces + */ + +#pragma once + +#include +#include +#include +#include +#include + +#include "gromacs/utility/mpicomm.h" + +namespace gmx +{ + +// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) +static std::mutex METATOMIC_TIMER_MUTEX = {}; +// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) +static thread_local int64_t METATOMIC_TIMER_DEPTH = -1; +// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) +static bool METATOMIC_TIMER_ENABLED = false; + +/*! \internal \brief RAII scoped timer for Metatomic profiling. + * + * Writes hierarchical timing info to a per-rank file + * (metatomic_timer_rank_N.log). Timers nest automatically via a + * global depth counter. Thread-safe via a global mutex. + * + * Enable with GMX_METATOMIC_TIMER=1 environment variable. + */ +class MetatomicTimer +{ +public: + //! Enable or disable all timers globally. + static void enable(bool toggle) + { + auto guard_ = std::lock_guard(METATOMIC_TIMER_MUTEX); + METATOMIC_TIMER_ENABLED = toggle; + } + + //! Construct a timer with the given label. Starts timing if enabled. + MetatomicTimer(std::string name, const MpiComm& mpiComm) : + enabled_(false), name_(std::move(name)), mpiComm_(mpiComm) + { + auto guard_ = std::lock_guard(METATOMIC_TIMER_MUTEX); + if (METATOMIC_TIMER_ENABLED) + { + METATOMIC_TIMER_DEPTH += 1; + this->enabled_ = true; + this->start_ = std::chrono::high_resolution_clock::now(); + } + } + + //! Stop the timer early (before scope exit). Safe to call multiple times. + void stop() + { + auto guard_ = std::lock_guard(METATOMIC_TIMER_MUTEX); + recordAndDisable_(); + } + + ~MetatomicTimer() + { + auto guard_ = std::lock_guard(METATOMIC_TIMER_MUTEX); + recordAndDisable_(); + } + + // Non-copyable, non-movable + MetatomicTimer(const MetatomicTimer&) = delete; + MetatomicTimer& operator=(const MetatomicTimer&) = delete; + MetatomicTimer(MetatomicTimer&&) = delete; + MetatomicTimer& operator=(MetatomicTimer&&) = delete; + +private: + //! Record elapsed time and mark as done. Must be called under lock. + void recordAndDisable_() + { + if (METATOMIC_TIMER_ENABLED && this->enabled_) + { + auto stop = std::chrono::high_resolution_clock::now(); + auto elapsed = std::chrono::duration_cast(stop - start_).count(); + auto indent = std::string(METATOMIC_TIMER_DEPTH * 2, ' '); + + std::string fname = "metatomic_timer_rank_" + std::to_string(mpiComm_.rank()) + ".log"; + FILE* fp = std::fopen(fname.c_str(), "a"); + if (fp) + { + std::fprintf(fp, "%s%s: %.3f ms\n", indent.c_str(), name_.c_str(), elapsed / 1e3); + std::fclose(fp); + } + + this->enabled_ = false; + METATOMIC_TIMER_DEPTH -= 1; + } + } + + bool enabled_; + std::string name_; + const MpiComm& mpiComm_; + std::chrono::high_resolution_clock::time_point start_; +}; + +} // namespace gmx diff --git a/src/gromacs/applied_forces/metatomic/tests/refdata/MetatomicOptionsTest_DefaultParameters.xml b/src/gromacs/applied_forces/metatomic/tests/refdata/MetatomicOptionsTest_DefaultParameters.xml index 21555a3181..6ba0b09913 100644 --- a/src/gromacs/applied_forces/metatomic/tests/refdata/MetatomicOptionsTest_DefaultParameters.xml +++ b/src/gromacs/applied_forces/metatomic/tests/refdata/MetatomicOptionsTest_DefaultParameters.xml @@ -6,5 +6,5 @@ - true + false diff --git a/src/gromacs/domdec/localtopology.cpp b/src/gromacs/domdec/localtopology.cpp index 85d5ee547d..f1ca846fd7 100644 --- a/src/gromacs/domdec/localtopology.cpp +++ b/src/gromacs/domdec/localtopology.cpp @@ -846,6 +846,14 @@ static int make_local_bondeds_excls(const gmx_domdec_t& dd, /* We only use exclusions from i-zones to i- and j-zones */ const int numIZonesForExclusions = (dd.haveExclusions ? zones.numIZones() : 0); + /* When intermolecular exclusions are present (e.g. from embedded/ML potentials) + * but there are no inter-atomic bonded interactions spanning zones, the outer loop + * must still cover all i-zones so that exclusion lists are built for all i-zone atoms. + * Without this, the exclusion list would only cover zone 0 atoms while the nbnxm + * pairlist construction expects exclusions for all i-zone atoms. + */ + nzone_bondeds = std::max(nzone_bondeds, numIZonesForExclusions); + const gmx_reverse_top_t& rt = *dd.reverse_top; const real cutoffSquared = gmx::square(cutoff);