diff --git a/src/openfe/data/_registry.py b/src/openfe/data/_registry.py index 7a87814dd..80a2aade6 100644 --- a/src/openfe/data/_registry.py +++ b/src/openfe/data/_registry.py @@ -17,8 +17,15 @@ fname="industry_benchmark_systems.zip", known_hash="sha256:2bb5eee36e29b718b96bf6e9350e0b9957a592f6c289f77330cbb6f4311a07bd", ) +zenodo_resume_data = dict( + base_url="doi:10.5281/zenodo.18331259", + fname="multistate_checkpoints.zip", + known_hash="md5:6addeabbfa37fd5f9114e3b043bfa568", +) + zenodo_data_registry = [ zenodo_rfe_simulation_nc, zenodo_t4_lysozyme_traj, zenodo_industry_benchmark_systems, + zenodo_resume_data, ] diff --git a/src/openfe/protocols/openmm_rfe/_rfe_utils/multistate.py b/src/openfe/protocols/openmm_rfe/_rfe_utils/multistate.py index 299a846f6..50acdd8e5 100644 --- a/src/openfe/protocols/openmm_rfe/_rfe_utils/multistate.py +++ b/src/openfe/protocols/openmm_rfe/_rfe_utils/multistate.py @@ -32,7 +32,13 @@ class HybridCompatibilityMixin: unsampled endpoints have a different number of degrees of freedom. """ - def __init__(self, *args, hybrid_system, hybrid_positions, **kwargs): + def __init__( + self, + *args, + hybrid_system: openmm.System | None = None, + hybrid_positions: unit.Quantity | None = None, + **kwargs + ): self._hybrid_system = hybrid_system self._hybrid_positions = hybrid_positions super(HybridCompatibilityMixin, self).__init__(*args, **kwargs) @@ -167,7 +173,13 @@ class HybridRepexSampler(HybridCompatibilityMixin, number of positions """ - def __init__(self, *args, hybrid_system, hybrid_positions, **kwargs): + def __init__( + self, + *args, + hybrid_system: openmm.System | None = None, + hybrid_positions: unit.Quantity | None = None, + **kwargs + ): super(HybridRepexSampler, self).__init__( *args, hybrid_system=hybrid_system, @@ -182,7 +194,13 @@ class HybridSAMSSampler(HybridCompatibilityMixin, sams.SAMSSampler): of positions """ - def __init__(self, *args, hybrid_system, hybrid_positions, **kwargs): + def __init__( + self, + *args, + hybrid_system: openmm.System | None = None, + hybrid_positions: unit.Quantity | None = None, + **kwargs + ): super(HybridSAMSSampler, self).__init__( *args, hybrid_system=hybrid_system, @@ -197,7 +215,13 @@ class HybridMultiStateSampler(HybridCompatibilityMixin, MultiStateSampler that supports unsample end states with a different number of positions """ - def __init__(self, *args, hybrid_system, hybrid_positions, **kwargs): + def __init__( + self, + *args, + hybrid_system: openmm.System | None = None, + hybrid_positions: unit.Quantity | None = None, + **kwargs + ): super(HybridMultiStateSampler, self).__init__( *args, hybrid_system=hybrid_system, diff --git a/src/openfe/protocols/openmm_rfe/hybridtop_units.py b/src/openfe/protocols/openmm_rfe/hybridtop_units.py index cd07d598b..a113b8a67 100644 --- a/src/openfe/protocols/openmm_rfe/hybridtop_units.py +++ b/src/openfe/protocols/openmm_rfe/hybridtop_units.py @@ -32,6 +32,7 @@ SmallMoleculeComponent, SolventComponent, ) +from gufe.protocols.errors import ProtocolUnitExecutionError from gufe.settings import ( SettingsBaseModel, ThermoSettings, @@ -43,6 +44,7 @@ from openmmforcefields.generators import SystemGenerator from openmmtools import multistate +import openfe from openfe.protocols.openmm_utils.omm_settings import ( BasePartialChargeSettings, ) @@ -143,6 +145,22 @@ def _get_settings( protocol_settings["engine_settings"] = settings.engine_settings return protocol_settings + @staticmethod + def _verify_execution_environment( + setup_outputs: dict[str, Any], + ) -> None: + """ + 
Check that the Python environment hasn't changed based on the + relevant Python library versions stored in the setup outputs. + """ + if ( + (gufe.__version__ != setup_outputs["gufe_version"]) + or (openfe.__version__ != setup_outputs["openfe_version"]) + or (openmm.__version__ != setup_outputs["openmm_version"]) + ): + errmsg = "Python environment has changed, cannot continue Protocol execution." + raise ProtocolUnitExecutionError(errmsg) + class HybridTopologySetupUnit(gufe.ProtocolUnit, HybridTopologyUnitMixin): """ @@ -781,6 +799,9 @@ def run( "positions": positions_outfile, "pdb_structure": self.shared_basepath / settings["output_settings"].output_structure, "selection_indices": selection_indices, + "openmm_version": openmm.__version__, + "openfe_version": openfe.__version__, + "gufe_version": gufe.__version__, } if dry: @@ -815,6 +836,44 @@ class HybridTopologyMultiStateSimulationUnit(gufe.ProtocolUnit, HybridTopologyUn replica exchange) unit for Hybrid Topology Protocol transformations. """ + @staticmethod + def _check_restart(output_settings: SettingsBaseModel, shared_path: pathlib.Path): + """ + Check if we are doing a restart. + + Parameters + ---------- + output_settings : SettingsBaseModel + The simulation output settings + shared_path : pathlib.Path + The shared directory where we should be looking for existing files. + + Notes + ----- + For now this just checks if the netcdf files are present in the + shared directory but in the future this may expand depending on + how warehouse works. + + Raises + ------ + IOError + If either the checkpoint or trajectory files don't exist. + """ + trajectory = shared_path / output_settings.output_filename + checkpoint = shared_path / output_settings.checkpoint_storage_filename + + if trajectory.is_file() ^ checkpoint.is_file(): + errmsg = ( + "One of either the trajectory or checkpoint files are missing but " + "the other is not. This should not happen under normal circumstances." + ) + raise IOError(errmsg) + + if trajectory.is_file() and checkpoint.is_file(): + return True + + return False + @staticmethod def _get_integrator( integrator_settings: IntegratorSettings, @@ -890,8 +949,16 @@ def _get_reporter( Settings defining how outputs should be written. simulation_settings : MultiStateSimulationSettings Settings defining out the simulation should be run. + + Notes + ----- + All this does is create the reporter, it works for both + new reporters and if we are doing a restart. """ + # Define the trajectory & checkpoint files nc = storage_path / output_settings.output_filename + # The checkpoint file in openmmtools is taken as a file relative + # to the location of the nc file, so you only want the filename chk = output_settings.checkpoint_storage_filename if output_settings.positions_write_frequency is not None: @@ -939,6 +1006,7 @@ def _get_sampler( thermo_settings: ThermoSettings, alchem_settings: AlchemicalSettings, platform: openmm.Platform, + restart: bool, dry: bool, ) -> multistate.MultiStateSampler: """ @@ -964,6 +1032,8 @@ def _get_sampler( The alchemical transformation settings. platform : openmm.Platform The compute platform to use. + restart : bool + ``True`` if we are doing a simulation restart. dry : bool Whether or not this is a dry run. @@ -972,10 +1042,31 @@ def _get_sampler( sampler : multistate.MultiStateSampler The requested sampler. 
""" + _SAMPLERS = { + "repex": _rfe_utils.multistate.HybridRepexSampler, + "sams": _rfe_utils.multistate.HybridSAMSSampler, + "independent": _rfe_utils.multistate.HybridMultiStateSampler, + } + + sampler_method = simulation_settings.sampler_method.lower() + + # Get the real time analysis values to use rta_its, rta_min_its = settings_validation.convert_real_time_analysis_iterations( simulation_settings=simulation_settings, ) + # Get the number of production iterations to run for + steps_per_iteration = integrator.n_steps + timestep = from_openmm(integrator.timestep) + number_of_iterations = int( + settings_validation.get_simsteps( + sim_length=simulation_settings.production_length, + timestep=timestep, + mc_steps=steps_per_iteration, + ) + / steps_per_iteration + ) + # convert early_termination_target_error from kcal/mol to kT early_termination_target_error = ( settings_validation.convert_target_error_from_kcal_per_mole_to_kT( @@ -984,51 +1075,57 @@ def _get_sampler( ) ) - if simulation_settings.sampler_method.lower() == "repex": - sampler = _rfe_utils.multistate.HybridRepexSampler( - mcmc_moves=integrator, - hybrid_system=system, - hybrid_positions=positions, - online_analysis_interval=rta_its, - online_analysis_target_error=early_termination_target_error, - online_analysis_minimum_iterations=rta_min_its, - ) + sampler_kwargs = { + "mcmc_moves": integrator, + "hybrid_system": system, + "hybrid_positions": positions, + "online_analysis_interval": rta_its, + "online_analysis_target_error": early_termination_target_error, + "online_analysis_minimum_iterations": rta_min_its, + "number_of_iterations": number_of_iterations, + } - elif simulation_settings.sampler_method.lower() == "sams": - sampler = _rfe_utils.multistate.HybridSAMSSampler( - mcmc_moves=integrator, - hybrid_system=system, - hybrid_positions=positions, - online_analysis_interval=rta_its, - online_analysis_minimum_iterations=rta_min_its, - flatness_criteria=simulation_settings.sams_flatness_criteria, - gamma0=simulation_settings.sams_gamma0, - ) + if sampler_method == "sams": + sampler_kwargs |= { + "flatness_criteria": simulation_settings.sams_flatness_criteria, + "gamma0": simulation_settings.sams_gamma0, + } + + if sampler_method == "repex": + sampler_kwargs |= {"replica_mixing_scheme": "swap-all"} - elif simulation_settings.sampler_method.lower() == "independent": - sampler = _rfe_utils.multistate.HybridMultiStateSampler( - mcmc_moves=integrator, - hybrid_system=system, - hybrid_positions=positions, - online_analysis_interval=rta_its, - online_analysis_target_error=early_termination_target_error, - online_analysis_minimum_iterations=rta_min_its, + # Restarting doesn't need any setup, we just rebuild from storage. + if restart: + sampler = _SAMPLERS[sampler_method].from_storage(reporter) # type: ignore[attr-defined] + + # We do some checks to make sure we are running the same system + system_validation.assert_multistate_system_equality( + ref_system=system, + stored_system=sampler._thermodynamic_states[0].get_system(remove_thermostat=True), ) + if ( + (simulation_settings.n_replicas != sampler.n_states != sampler.n_replicas) + or (sampler.mcmc_moves[0].n_steps != steps_per_iteration) + or (sampler.mcmc_moves[0].timestep != integrator.timestep) + ): + errmsg = "Sampler in checkpoint does not match Protocol settings, cannot resume." 
+ raise ValueError(errmsg) + else: - raise AttributeError(f"Unknown sampler {simulation_settings.sampler_method}") + sampler = _SAMPLERS[sampler_method](**sampler_kwargs) - sampler.setup( - n_replicas=simulation_settings.n_replicas, - reporter=reporter, - lambda_protocol=lambdas, - temperature=to_openmm(thermo_settings.temperature), - endstates=alchem_settings.endstate_dispersion_correction, - minimization_platform=platform.getName(), - # Set minimization steps to None when running in dry mode - # otherwise do a very small one to avoid NaNs - minimization_steps=100 if not dry else None, - ) + sampler.setup( + n_replicas=simulation_settings.n_replicas, + reporter=reporter, + lambda_protocol=lambdas, + temperature=to_openmm(thermo_settings.temperature), + endstates=alchem_settings.endstate_dispersion_correction, + minimization_platform=platform.getName(), + # Set minimization steps to None when running in dry mode + # otherwise do a very small one to avoid NaNs + minimization_steps=100 if not dry else None, + ) # Get and set the context caches sampler.energy_context_cache = openmmtools.cache.ContextCache( @@ -1089,23 +1186,28 @@ def _run_simulation( ) if not dry: # pragma: no-cover - # minimize - if self.verbose: - self.logger.info("minimizing systems") + # No productions steps have been taken, so start from scratch + if sampler._iteration == 0: + # minimize + if self.verbose: + self.logger.info("minimizing systems") - sampler.minimize(max_iterations=simulation_settings.minimization_steps) + sampler.minimize(max_iterations=simulation_settings.minimization_steps) - # equilibrate - if self.verbose: - self.logger.info("equilibrating systems") + # equilibrate + if self.verbose: + self.logger.info("equilibrating systems") - sampler.equilibrate(int(equil_steps / mc_steps)) + sampler.equilibrate(int(equil_steps / mc_steps)) - # production + # At this point we are ready for production if self.verbose: self.logger.info("running production phase") - sampler.extend(int(prod_steps / mc_steps)) + # We use `run` so that we're limited by the number of iterations + # we passed when we built the sampler. + # TODO: I'm being extra prudent by passing in n_iterations here - remove? 
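+            # On a restart, ``sampler._iteration`` already reflects the iterations
+            # restored from the checkpoint, so only the remaining production
+            # iterations are requested here.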
+            sampler.run(n_iterations=int(prod_steps / mc_steps) - sampler._iteration)
 
             if self.verbose:
                 self.logger.info("production phase complete")
@@ -1173,6 +1275,11 @@ def run(
         # Get the settings
         settings = self._get_settings(self._inputs["protocol"].settings)
 
+        # Check for a restart
+        self.restart = self._check_restart(
+            output_settings=settings["output_settings"], shared_path=self.shared_basepath
+        )
+
         # Get the lambda schedule
         # TODO - this should be better exposed to users
         lambdas = _rfe_utils.lambdaprotocol.LambdaProtocol(
@@ -1204,7 +1311,7 @@ def run(
             simulation_settings=settings["simulation_settings"],
         )
 
-        # Get sampler
+        # Get the sampler
         sampler = self._get_sampler(
             system=system,
             positions=positions,
@@ -1215,9 +1322,11 @@ def run(
             thermo_settings=settings["thermo_settings"],
             alchem_settings=settings["alchemical_settings"],
             platform=platform,
+            restart=self.restart,
             dry=dry,
         )
 
+        # Run the simulation
         self._run_simulation(
             sampler=sampler,
             reporter=reporter,
@@ -1273,7 +1382,10 @@ def _execute(
         **inputs,
     ) -> dict[str, Any]:
         log_system_probe(logging.INFO, paths=[ctx.scratch])
-        # Get the relevant inputs
+        # Ensure that the environment hasn't changed
+        self._verify_execution_environment(setup_results.outputs)
+
+        # Get the relevant inputs for running the unit
         system = deserialize(setup_results.outputs["system"])
         positions = to_openmm(np.load(setup_results.outputs["positions"]) * offunit.nm)
         selection_indices = setup_results.outputs["selection_indices"]
@@ -1501,6 +1613,9 @@ def _execute(
     ) -> dict[str, Any]:
         log_system_probe(logging.INFO, paths=[ctx.scratch])
 
+        # Ensure that the environment hasn't changed
+        self._verify_execution_environment(setup_results.outputs)
+
         pdb_file = setup_results.outputs["pdb_structure"]
         selection_indices = setup_results.outputs["selection_indices"]
         trajectory = simulation_results.outputs["nc"]
diff --git a/src/openfe/protocols/openmm_utils/system_validation.py b/src/openfe/protocols/openmm_utils/system_validation.py
index 3e8ed5c50..7b45a077a 100644
--- a/src/openfe/protocols/openmm_utils/system_validation.py
+++ b/src/openfe/protocols/openmm_utils/system_validation.py
@@ -7,6 +7,8 @@
 
 from typing import Optional, Tuple
 
+import numpy as np
+import openmm
 from gufe import (
     ChemicalSystem,
     Component,
@@ -177,3 +179,102 @@ def _get_single_comps(state, comptype):
     small_mols = state.get_components_of_type(SmallMoleculeComponent)
 
     return solvent_comp, protein_comp, small_mols
+
+
+def assert_multistate_system_equality(
+    ref_system: openmm.System,
+    stored_system: openmm.System,
+):
+    """
+    Verify that the System recovered from a MultiStateReporter
+    checkpoint matches a pre-existing reference System for the
+    same simulation.
+
+    Raises
+    ------
+    ValueError
+      * If the particles in the two Systems don't match.
+      * If the constraints in the two Systems don't match.
+      * If the forces in the two Systems don't match.
+ """ + + # Assert particle equality + def _get_masses(system): + return np.array( + [ + system.getParticleMass(i).value_in_unit(openmm.unit.dalton) + for i in range(system.getNumParticles()) + ] + ) + + ref_masses = _get_masses(ref_system) + stored_masses = _get_masses(stored_system) + + if not ((ref_masses.shape == stored_masses.shape) and (np.allclose(ref_masses, stored_masses))): + errmsg = "Stored checkpoint System particles do not match those of the simulated System" + raise ValueError(errmsg) + + # Assert constraint equality + def _get_constraints(system): + constraints = [] + for index in range(system.getNumConstraints()): + i, j, d = system.getConstraintParameters(index) + constraints.append([i, j, d.value_in_unit(openmm.unit.nanometer)]) + + return np.array(constraints) + + ref_constraints = _get_constraints(ref_system) + stored_constraints = _get_constraints(stored_system) + + if not ( + (ref_constraints.shape == stored_constraints.shape) + and (np.allclose(ref_constraints, stored_constraints)) + ): + errmsg = "Stored checkpoint System constraints do not match those of the simulation System" + raise ValueError(errmsg) + + # Assert force equality + # Notes: + # * Store forces are in different order + # * The barostat doesn't exactly match because seeds have changed + + # Create dictionaries of forces keyed by their hash + # Note: we can't rely on names because they may clash + ref_force_dict = {hash(openmm.XmlSerializer.serialize(f)): f for f in ref_system.getForces()} + stored_force_dict = { + hash(openmm.XmlSerializer.serialize(f)): f for f in stored_system.getForces() + } + + # Assert the number of forces is equal + if len(ref_force_dict) != len(stored_force_dict): + errmsg = "Number of forces stored in checkpoint System does not match simulation System" + raise ValueError(errmsg) + + # Loop through forces and check for equality + for sfhash, sforce in stored_force_dict.items(): + errmsg = ( + f"Force {sforce.getName()} in the stored checkpoint System " + "does not match the same force in the simulated System" + ) + + # Barostat case - seed changed so we need to check manually + if isinstance(sforce, (openmm.MonteCarloBarostat, openmm.MonteCarloMembraneBarostat)): + # Find the equivalent force in the reference + rforce = [ + f + for f in ref_force_dict.values() + if isinstance(f, (openmm.MonteCarloBarostat, openmm.MonteCarloMembraneBarostat)) + ][0] + + if ( + (sforce.getFrequency() != rforce.getFrequency()) + or (sforce.getForceGroup() != rforce.getForceGroup()) + or (sforce.getDefaultPressure() != rforce.getDefaultPressure()) + or (sforce.getDefaultTemperature() != rforce.getDefaultTemperature()) + ): + raise ValueError(errmsg) + + else: + if sfhash not in ref_force_dict: + raise ValueError(errmsg) diff --git a/src/openfe/tests/conftest.py b/src/openfe/tests/conftest.py index deaad59e3..e3c5c84d4 100644 --- a/src/openfe/tests/conftest.py +++ b/src/openfe/tests/conftest.py @@ -12,7 +12,6 @@ import numpy as np import openmm import pandas as pd -import pooch import pytest from gufe import AtomMapper, LigandAtomMapping, ProteinComponent, SmallMoleculeComponent from openff.toolkit import ForceField diff --git a/src/openfe/tests/protocols/conftest.py b/src/openfe/tests/protocols/conftest.py index b5f302947..6978148a4 100644 --- a/src/openfe/tests/protocols/conftest.py +++ b/src/openfe/tests/protocols/conftest.py @@ -20,6 +20,7 @@ from openfe.data._registry import ( POOCH_CACHE, zenodo_industry_benchmark_systems, + zenodo_resume_data, zenodo_rfe_simulation_nc, 
zenodo_t4_lysozyme_traj, ) @@ -334,6 +335,31 @@ def simulation_nc(): ) +pooch_resume_data = pooch.create( + path=POOCH_CACHE, + base_url=zenodo_resume_data["base_url"], + registry={zenodo_resume_data["fname"]: zenodo_resume_data["known_hash"]}, +) + + +@pytest.fixture(scope="session") +def htop_trajectory_path(): + pooch_resume_data.fetch("multistate_checkpoints.zip", processor=pooch.Unzip()) + topdir = "multistate_checkpoints.zip.unzip/multistate_checkpoints" + subdir = "hybrid_top" + filename = "simulation.nc" + return pathlib.Path(pooch.os_cache("openfe") / f"{topdir}/{subdir}/{filename}") + + +@pytest.fixture(scope="session") +def htop_checkpoint_path(): + pooch_resume_data.fetch("multistate_checkpoints.zip", processor=pooch.Unzip()) + topdir = "multistate_checkpoints.zip.unzip/multistate_checkpoints" + subdir = "hybrid_top" + filename = "checkpoint.chk" + return pathlib.Path(pooch.os_cache("openfe") / f"{topdir}/{subdir}/{filename}") + + @pytest.fixture def get_available_openmm_platforms() -> set[str]: """ diff --git a/src/openfe/tests/protocols/openmm_rfe/test_hybrid_top_protocol.py b/src/openfe/tests/protocols/openmm_rfe/test_hybrid_top_protocol.py index bd7a1f72f..5811ea973 100644 --- a/src/openfe/tests/protocols/openmm_rfe/test_hybrid_top_protocol.py +++ b/src/openfe/tests/protocols/openmm_rfe/test_hybrid_top_protocol.py @@ -12,6 +12,7 @@ import gufe import mdtraj as mdt import numpy as np +import openmm import pytest from kartograf import KartografAtomMapper from kartograf.atom_aligner import align_mol_shape @@ -1154,9 +1155,8 @@ def solvent_protocol_dag(benzene_system, toluene_system, benzene_to_toluene_mapp ) -def test_unit_tagging(solvent_protocol_dag, tmpdir): - # test that executing the Units includes correct generation and repeat info - dag_units = solvent_protocol_dag.protocol_units +@pytest.fixture() +def unit_mock_patcher(): with ( mock.patch( "openfe.protocols.openmm_rfe.hybridtop_units.HybridTopologySetupUnit.run", @@ -1165,6 +1165,9 @@ def test_unit_tagging(solvent_protocol_dag, tmpdir): "positions": Path("positions.npy"), "pdb_structure": Path("hybrid_system.pdb"), "selection_indices": np.zeros(100), + "gufe_version": gufe.__version__, + "openfe_version": openfe.__version__, + "openmm_version": openmm.__version__, }, ), mock.patch( @@ -1191,31 +1194,39 @@ def test_unit_tagging(solvent_protocol_dag, tmpdir): }, ), ): - setup_results = {} - sim_results = {} - analysis_results = {} - - setup_units = _get_units(dag_units, HybridTopologySetupUnit) - sim_units = _get_units(dag_units, HybridTopologyMultiStateSimulationUnit) - analysis_units = _get_units(dag_units, HybridTopologyMultiStateAnalysisUnit) - - for u in setup_units: - rid = u.inputs["repeat_id"] - setup_results[rid] = u.execute(context=gufe.Context(tmpdir, tmpdir)) - - for u in sim_units: - rid = u.inputs["repeat_id"] - sim_results[rid] = u.execute( - context=gufe.Context(tmpdir, tmpdir), setup_results=setup_results[rid] - ) + yield + + +def test_unit_tagging(solvent_protocol_dag, unit_mock_patcher, tmpdir): + # test that executing the Units includes correct generation and repeat info + dag_units = solvent_protocol_dag.protocol_units + + setup_results = {} + sim_results = {} + analysis_results = {} + + setup_units = _get_units(dag_units, HybridTopologySetupUnit) + sim_units = _get_units(dag_units, HybridTopologyMultiStateSimulationUnit) + analysis_units = _get_units(dag_units, HybridTopologyMultiStateAnalysisUnit) + + for u in setup_units: + rid = u.inputs["repeat_id"] + setup_results[rid] = 
u.execute(context=gufe.Context(tmpdir, tmpdir)) + + for u in sim_units: + rid = u.inputs["repeat_id"] + sim_results[rid] = u.execute( + context=gufe.Context(tmpdir, tmpdir), setup_results=setup_results[rid] + ) + + for u in analysis_units: + rid = u.inputs["repeat_id"] + analysis_results[rid] = u.execute( + context=gufe.Context(tmpdir, tmpdir), + setup_results=setup_results[rid], + simulation_results=sim_results[rid], + ) - for u in analysis_units: - rid = u.inputs["repeat_id"] - analysis_results[rid] = u.execute( - context=gufe.Context(tmpdir, tmpdir), - setup_results=setup_results[rid], - simulation_results=sim_results[rid], - ) for results in [setup_results, sim_results, analysis_results]: for ret in results.values(): assert isinstance(ret, gufe.ProtocolUnitResult) @@ -1225,48 +1236,14 @@ def test_unit_tagging(solvent_protocol_dag, tmpdir): assert len(setup_results) == len(sim_results) == len(analysis_results) == 3 -def test_gather(solvent_protocol_dag, tmpdir): +def test_gather(solvent_protocol_dag, unit_mock_patcher, tmpdir): # check .gather behaves as expected - with ( - mock.patch( - "openfe.protocols.openmm_rfe.hybridtop_units.HybridTopologySetupUnit.run", - return_value={ - "system": Path("system.xml.bz2"), - "positions": Path("positions.npy"), - "pdb_structure": Path("hybrid_system.pdb"), - "selection_indices": np.zeros(100), - }, - ), - mock.patch( - "openfe.protocols.openmm_rfe.hybridtop_units.np.load", - return_value=np.zeros(100), - ), - mock.patch( - "openfe.protocols.openmm_rfe.hybridtop_units.deserialize", - return_value={ - "item": "foo", - }, - ), - mock.patch( - "openfe.protocols.openmm_rfe.hybridtop_units.HybridTopologyMultiStateSimulationUnit.run", - return_value={ - "nc": Path("file.nc"), - "checkpoint": Path("chk.chk"), - }, - ), - mock.patch( - "openfe.protocols.openmm_rfe.hybridtop_units.HybridTopologyMultiStateAnalysisUnit.run", - return_value={ - "foo": "bar", - }, - ), - ): - dagres = gufe.protocols.execute_DAG( - solvent_protocol_dag, - shared_basedir=tmpdir, - scratch_basedir=tmpdir, - keep_shared=True, - ) + dagres = gufe.protocols.execute_DAG( + solvent_protocol_dag, + shared_basedir=tmpdir, + scratch_basedir=tmpdir, + keep_shared=True, + ) prot = openmm_rfe.RelativeHybridTopologyProtocol( settings=openmm_rfe.RelativeHybridTopologyProtocol.default_settings() diff --git a/src/openfe/tests/protocols/openmm_rfe/test_hybrid_top_resume.py b/src/openfe/tests/protocols/openmm_rfe/test_hybrid_top_resume.py new file mode 100644 index 000000000..dcfbabb37 --- /dev/null +++ b/src/openfe/tests/protocols/openmm_rfe/test_hybrid_top_resume.py @@ -0,0 +1,450 @@ +# This code is part of OpenFE and is licensed under the MIT license. 
+# For details, see https://github.com/OpenFreeEnergy/openfe +import copy +import os +import pathlib +import shutil + +import gufe +import numpy as np +import openmm +import pooch +import pytest +from gufe.protocols import execute_DAG +from gufe.protocols.errors import ProtocolUnitExecutionError +from numpy.testing import assert_allclose +from openfe_analysis.utils.multistate import _determine_position_indices +from openff.units import unit as offunit +from openff.units.openmm import from_openmm +from openmmtools.multistate import MultiStateReporter + +import openfe +from openfe.data._registry import POOCH_CACHE +from openfe.protocols import openmm_rfe +from openfe.protocols.openmm_rfe._rfe_utils.multistate import HybridRepexSampler +from openfe.protocols.openmm_rfe.hybridtop_units import ( + HybridTopologyMultiStateAnalysisUnit, + HybridTopologyMultiStateSimulationUnit, + HybridTopologySetupUnit, +) + +from ...conftest import HAS_INTERNET +from .test_hybrid_top_protocol import _get_units + + +@pytest.fixture() +def protocol_settings(): + settings = openmm_rfe.RelativeHybridTopologyProtocol.default_settings() + settings.solvation_settings.solvent_padding = None + settings.solvation_settings.number_of_solvent_molecules = 750 + settings.solvation_settings.box_shape = "dodecahedron" + settings.protocol_repeats = 1 + settings.simulation_settings.equilibration_length = 100 * offunit.picosecond + settings.simulation_settings.production_length = 200 * offunit.picosecond + settings.simulation_settings.time_per_iteration = 2.5 * offunit.picosecond + settings.output_settings.checkpoint_interval = 100 * offunit.picosecond + settings.engine_settings.compute_platform = None + return settings + + +def test_verify_execution_environment(): + # Verification should pass + openmm_rfe.HybridTopologyMultiStateSimulationUnit._verify_execution_environment( + setup_outputs={ + "gufe_version": gufe.__version__, + "openfe_version": openfe.__version__, + "openmm_version": openmm.__version__, + }, + ) + + +def test_verify_execution_environment_fail(): + # Passing a bad version should fail + with pytest.raises(ProtocolUnitExecutionError, match="Python environment"): + openmm_rfe.HybridTopologyMultiStateSimulationUnit._verify_execution_environment( + setup_outputs={ + "gufe_version": 0.1, + "openfe_version": openfe.__version__, + "openmm_version": openmm.__version__, + }, + ) + + +@pytest.mark.skipif( + not os.path.exists(POOCH_CACHE) and not HAS_INTERNET, + reason="Internet unavailable and test data is not cached locally", +) +def test_check_restart(protocol_settings, htop_trajectory_path): + assert openmm_rfe.HybridTopologyMultiStateSimulationUnit._check_restart( + output_settings=protocol_settings.output_settings, + shared_path=htop_trajectory_path.parent, + ) + + assert not openmm_rfe.HybridTopologyMultiStateSimulationUnit._check_restart( + output_settings=protocol_settings.output_settings, + shared_path=pathlib.Path("."), + ) + + +@pytest.mark.skipif( + not os.path.exists(POOCH_CACHE) and not HAS_INTERNET, + reason="Internet unavailable and test data is not cached locally", +) +class TestCheckpointResuming: + @pytest.fixture() + def protocol_dag( + self, protocol_settings, benzene_system, toluene_system, benzene_to_toluene_mapping + ): + protocol = openmm_rfe.RelativeHybridTopologyProtocol(settings=protocol_settings) + + return protocol.create( + stateA=benzene_system, stateB=toluene_system, mapping=benzene_to_toluene_mapping + ) + + @staticmethod + def _check_sampler(sampler, num_iterations: int): + # Helper 
method to do some checks on the sampler + assert sampler._iteration == num_iterations + assert sampler.number_of_iterations == 80 + assert sampler.is_completed is (num_iterations == 80) + assert sampler.n_states == sampler.n_replicas == 11 + assert sampler.is_periodic + assert sampler.mcmc_moves[0].n_steps == 625 + assert from_openmm(sampler.mcmc_moves[0].timestep) == 4 * offunit.fs + + @staticmethod + def _get_positions(dataset): + frame_list = _determine_position_indices(dataset) + positions = [] + for frame in frame_list: + positions.append(copy.deepcopy(dataset.variables["positions"][frame].data)) + return positions + + @staticmethod + def _copy_simfiles(cwd: pathlib.Path, filepath): + shutil.copyfile(filepath, f"{cwd}/{filepath.name}") + + @pytest.mark.integration + def test_resume(self, protocol_dag, htop_trajectory_path, htop_checkpoint_path, tmpdir): + """ + Attempt to resume a simulation unit with pre-existing checkpoint & + trajectory files. + """ + # define a temp directory path & copy files + cwd = pathlib.Path(str(tmpdir)) + self._copy_simfiles(cwd, htop_trajectory_path) + self._copy_simfiles(cwd, htop_checkpoint_path) + + # 1. Check that the trajectory / checkpoint contain what we expect + reporter = MultiStateReporter( + f"{cwd}/simulation.nc", + checkpoint_storage="checkpoint.chk", + ) + sampler = HybridRepexSampler.from_storage(reporter) + + self._check_sampler(sampler, num_iterations=40) + # Deep copy energies & positions for later tests + init_energies = copy.deepcopy(reporter.read_energies())[0] + assert init_energies.shape == (41, 11, 11) + init_positions = self._get_positions(reporter._storage[0]) + assert len(init_positions) == 2 + + reporter.close() + del sampler + + # 2. get & run the units + pus = list(protocol_dag.protocol_units) + setup_unit = _get_units(pus, HybridTopologySetupUnit)[0] + simulation_unit = _get_units(pus, HybridTopologyMultiStateSimulationUnit)[0] + analysis_unit = _get_units(pus, HybridTopologyMultiStateAnalysisUnit)[0] + + # Dry run the setup since it'll be easier to use the objects directly + setup_results = setup_unit.run(dry=True, scratch_basepath=cwd, shared_basepath=cwd) + + # Now we run the simulation in resume mode + sim_results = simulation_unit.run( + system=setup_results["hybrid_system"], + positions=setup_results["hybrid_positions"], + selection_indices=setup_results["selection_indices"], + scratch_basepath=cwd, + shared_basepath=cwd, + ) + + # TODO: can't do this right now: openfe-analysis isn't closing + # netcdf files properly, so we can't do any follow-up operations + # Once openfe-analysis is released, add tests for this. + # # Finally we analyze the results + # analysis_results = analysis_unit.run( + # pdb_file=setup_results["pdb_structure"], + # trajectory=sim_results["nc"], + # checkpoint=sim_results["checkpoint"], + # scratch_basepath=cwd, + # shared_basepath=cwd, + # ) + + # 3. 
Analyze the trajectory/checkpoint again + reporter = MultiStateReporter( + f"{cwd}/simulation.nc", + checkpoint_storage="checkpoint.chk", + ) + sampler = HybridRepexSampler.from_storage(reporter) + + self._check_sampler(sampler, num_iterations=80) + + # Check the energies and positions + energies = reporter.read_energies()[0] + assert energies.shape == (81, 11, 11) + assert_allclose(init_energies, energies[:41]) + + positions = self._get_positions(reporter._storage[0]) + assert len(positions) == 3 + for i in range(2): + assert_allclose(positions[i], init_positions[i]) + + reporter.close() + del sampler + + @pytest.mark.slow + def test_resume_fail_particles( + self, protocol_dag, htop_trajectory_path, htop_checkpoint_path, tmpdir + ): + """ + Test that the run unit will fail with a system incompatible + to the one present in the trajectory/checkpoint files. + + Here we check that we don't have the same particles / mass. + """ + # define a temp directory path & copy files + cwd = pathlib.Path(str(tmpdir)) + self._copy_simfiles(cwd, htop_trajectory_path) + self._copy_simfiles(cwd, htop_checkpoint_path) + + pus = list(protocol_dag.protocol_units) + setup_unit = _get_units(pus, HybridTopologySetupUnit)[0] + simulation_unit = _get_units(pus, HybridTopologyMultiStateSimulationUnit)[0] + analysis_unit = _get_units(pus, HybridTopologyMultiStateAnalysisUnit)[0] + + # Dry run the setup since it'll be easier to use the objects directly + setup_results = setup_unit.run(dry=True, scratch_basepath=cwd, shared_basepath=cwd) + + # Fake system should trigger a mismatch + errmsg = "Stored checkpoint System particles do not" + with pytest.raises(ValueError, match=errmsg): + sim_results = simulation_unit.run( + system=openmm.System(), + positions=setup_results["hybrid_positions"], + selection_indices=setup_results["selection_indices"], + scratch_basepath=cwd, + shared_basepath=cwd, + ) + + @pytest.mark.slow + def test_resume_fail_constraints( + self, protocol_dag, htop_trajectory_path, htop_checkpoint_path, tmpdir + ): + """ + Test that the run unit will fail with a system incompatible + to the one present in the trajectory/checkpoint files. + + Here we check that we don't have the same constraints. 
+ """ + # define a temp directory path & copy files + cwd = pathlib.Path(str(tmpdir)) + self._copy_simfiles(cwd, htop_trajectory_path) + self._copy_simfiles(cwd, htop_checkpoint_path) + + pus = list(protocol_dag.protocol_units) + setup_unit = _get_units(pus, HybridTopologySetupUnit)[0] + simulation_unit = _get_units(pus, HybridTopologyMultiStateSimulationUnit)[0] + analysis_unit = _get_units(pus, HybridTopologyMultiStateAnalysisUnit)[0] + + # Dry run the setup since it'll be easier to use the objects directly + setup_results = setup_unit.run(dry=True, scratch_basepath=cwd, shared_basepath=cwd) + + # Create a fake system without constraints + fake_system = copy.deepcopy(setup_results["hybrid_system"]) + + for i in reversed(range(fake_system.getNumConstraints())): + fake_system.removeConstraint(i) + + # Fake system should trigger a mismatch + errmsg = "Stored checkpoint System constraints do not" + with pytest.raises(ValueError, match=errmsg): + sim_results = simulation_unit.run( + system=fake_system, + positions=setup_results["hybrid_positions"], + selection_indices=setup_results["selection_indices"], + scratch_basepath=cwd, + shared_basepath=cwd, + ) + + @pytest.mark.slow + def test_resume_fail_forces( + self, protocol_dag, htop_trajectory_path, htop_checkpoint_path, tmpdir + ): + """ + Test that the run unit will fail with a system incompatible + to the one present in the trajectory/checkpoint files. + + Here we check we don't have the same forces. + """ + # define a temp directory path & copy files + cwd = pathlib.Path(str(tmpdir)) + self._copy_simfiles(cwd, htop_trajectory_path) + self._copy_simfiles(cwd, htop_checkpoint_path) + + pus = list(protocol_dag.protocol_units) + setup_unit = _get_units(pus, HybridTopologySetupUnit)[0] + simulation_unit = _get_units(pus, HybridTopologyMultiStateSimulationUnit)[0] + analysis_unit = _get_units(pus, HybridTopologyMultiStateAnalysisUnit)[0] + + # Dry run the setup since it'll be easier to use the objects directly + setup_results = setup_unit.run(dry=True, scratch_basepath=cwd, shared_basepath=cwd) + + # Create a fake system without the last force + fake_system = copy.deepcopy(setup_results["hybrid_system"]) + fake_system.removeForce(fake_system.getNumForces() - 1) + + # Fake system should trigger a mismatch + errmsg = "Number of forces stored in checkpoint System" + with pytest.raises(ValueError, match=errmsg): + sim_results = simulation_unit.run( + system=fake_system, + positions=setup_results["hybrid_positions"], + selection_indices=setup_results["selection_indices"], + scratch_basepath=cwd, + shared_basepath=cwd, + ) + + @pytest.mark.slow + @pytest.mark.parametrize("forcetype", [openmm.NonbondedForce, openmm.MonteCarloBarostat]) + def test_resume_differ_forces( + self, forcetype, protocol_dag, htop_trajectory_path, htop_checkpoint_path, tmpdir + ): + """ + Test that the run unit will fail with a system incompatible + to the one present in the trajectory/checkpoint files. 
+ + Here we check we have a different force + """ + # define a temp directory path & copy files + cwd = pathlib.Path(str(tmpdir)) + self._copy_simfiles(cwd, htop_trajectory_path) + self._copy_simfiles(cwd, htop_checkpoint_path) + + pus = list(protocol_dag.protocol_units) + setup_unit = _get_units(pus, HybridTopologySetupUnit)[0] + simulation_unit = _get_units(pus, HybridTopologyMultiStateSimulationUnit)[0] + analysis_unit = _get_units(pus, HybridTopologyMultiStateAnalysisUnit)[0] + + # Dry run the setup since it'll be easier to use the objects directly + setup_results = setup_unit.run(dry=True, scratch_basepath=cwd, shared_basepath=cwd) + + # Create a fake system with the fake forcetype + fake_system = copy.deepcopy(setup_results["hybrid_system"]) + + # Loop through forces and remove the force matching forcetype + for i, f in enumerate(fake_system.getForces()): + if isinstance(f, forcetype): + findex = i + + fake_system.removeForce(findex) + + # Now add a fake force + if forcetype == openmm.MonteCarloBarostat: + new_force = forcetype(1 * openmm.unit.atmosphere, 300 * openmm.unit.kelvin, 100) + else: + new_force = forcetype() + + fake_system.addForce(new_force) + + # Fake system should trigger a mismatch + errmsg = "stored checkpoint System does not match the same force" + with pytest.raises(ValueError, match=errmsg): + sim_results = simulation_unit.run( + system=fake_system, + positions=setup_results["hybrid_positions"], + selection_indices=setup_results["selection_indices"], + scratch_basepath=cwd, + shared_basepath=cwd, + ) + + @pytest.mark.slow + @pytest.mark.parametrize("bad_file", ["trajectory", "checkpoint"]) + def test_resume_bad_files( + self, protocol_dag, htop_trajectory_path, htop_checkpoint_path, bad_file, tmpdir + ): + """ + Test what happens when you have a bad trajectory and/or checkpoint + files. + """ + # define a temp directory path & copy files + cwd = pathlib.Path(str(tmpdir)) + + if bad_file == "trajectory": + with open(f"{cwd}/simulation.nc", "w") as f: + f.write("foo") + else: + self._copy_simfiles(cwd, htop_trajectory_path) + + if bad_file == "checkpoint": + with open(f"{cwd}/checkpoint.chk", "w") as f: + f.write("bar") + else: + self._copy_simfiles(cwd, htop_checkpoint_path) + + pus = list(protocol_dag.protocol_units) + setup_unit = _get_units(pus, HybridTopologySetupUnit)[0] + simulation_unit = _get_units(pus, HybridTopologyMultiStateSimulationUnit)[0] + analysis_unit = _get_units(pus, HybridTopologyMultiStateAnalysisUnit)[0] + + # Dry run the setup since it'll be easier to use the objects directly + setup_results = setup_unit.run(dry=True, scratch_basepath=cwd, shared_basepath=cwd) + + with pytest.raises(OSError, match="Unknown file format"): + sim_results = simulation_unit.run( + system=setup_results["hybrid_system"], + positions=setup_results["hybrid_positions"], + selection_indices=setup_results["selection_indices"], + scratch_basepath=cwd, + shared_basepath=cwd, + ) + + @pytest.mark.slow + @pytest.mark.parametrize("missing_file", ["trajectory", "checkpoint"]) + def test_missing_file( + self, protocol_dag, htop_trajectory_path, htop_checkpoint_path, missing_file, tmpdir + ): + """ + Test that an error is thrown if either file is missing but the other isn't. 
+ """ + # define a temp directory path & copy files + cwd = pathlib.Path(str(tmpdir)) + + if missing_file == "trajectory": + pass + else: + self._copy_simfiles(cwd, htop_trajectory_path) + + if missing_file == "checkpoint": + pass + else: + self._copy_simfiles(cwd, htop_checkpoint_path) + + pus = list(protocol_dag.protocol_units) + setup_unit = _get_units(pus, HybridTopologySetupUnit)[0] + simulation_unit = _get_units(pus, HybridTopologyMultiStateSimulationUnit)[0] + analysis_unit = _get_units(pus, HybridTopologyMultiStateAnalysisUnit)[0] + + # Dry run the setup since it'll be easier to use the objects directly + setup_results = setup_unit.run(dry=True, scratch_basepath=cwd, shared_basepath=cwd) + + errmsg = "One of either the trajectory or checkpoint files are missing" + with pytest.raises(IOError, match=errmsg): + sim_results = simulation_unit.run( + system=setup_results["hybrid_system"], + positions=setup_results["hybrid_positions"], + selection_indices=setup_results["selection_indices"], + scratch_basepath=cwd, + shared_basepath=cwd, + ) diff --git a/src/openfe/tests/protocols/restraints/test_geometry_boresch.py b/src/openfe/tests/protocols/restraints/test_geometry_boresch.py index 5231deef5..d650333b6 100644 --- a/src/openfe/tests/protocols/restraints/test_geometry_boresch.py +++ b/src/openfe/tests/protocols/restraints/test_geometry_boresch.py @@ -4,7 +4,6 @@ import pathlib import MDAnalysis as mda -import pooch import pytest from openff.units import unit from rdkit import Chem @@ -15,7 +14,7 @@ find_boresch_restraint, ) -from ...conftest import HAS_INTERNET, POOCH_CACHE +from ...conftest import HAS_INTERNET @pytest.fixture() diff --git a/src/openfe/tests/protocols/restraints/test_geometry_boresch_host.py b/src/openfe/tests/protocols/restraints/test_geometry_boresch_host.py index 5f89cf5aa..e5437909f 100644 --- a/src/openfe/tests/protocols/restraints/test_geometry_boresch_host.py +++ b/src/openfe/tests/protocols/restraints/test_geometry_boresch_host.py @@ -5,7 +5,6 @@ import MDAnalysis as mda import numpy as np -import pooch import pytest from numpy.testing import assert_equal from openff.units import unit diff --git a/src/openfe/tests/protocols/restraints/test_geometry_utils.py b/src/openfe/tests/protocols/restraints/test_geometry_utils.py index 077c32c24..454147cc7 100644 --- a/src/openfe/tests/protocols/restraints/test_geometry_utils.py +++ b/src/openfe/tests/protocols/restraints/test_geometry_utils.py @@ -31,7 +31,7 @@ stable_secondary_structure_selection, ) -from ...conftest import HAS_INTERNET, POOCH_CACHE +from ...conftest import HAS_INTERNET @pytest.fixture(scope="module")