From 6a7e8b7240ae844de43b16cda7b25f8d9698f9d7 Mon Sep 17 00:00:00 2001 From: IAlibay Date: Thu, 22 Jan 2026 18:34:41 +0000 Subject: [PATCH 1/3] A start --- .../protocols/openmm_afe/base_afe_units.py | 30 +++++++++++++++++-- 1 file changed, 27 insertions(+), 3 deletions(-) diff --git a/src/openfe/protocols/openmm_afe/base_afe_units.py b/src/openfe/protocols/openmm_afe/base_afe_units.py index 4095982ab..20ea73fe3 100644 --- a/src/openfe/protocols/openmm_afe/base_afe_units.py +++ b/src/openfe/protocols/openmm_afe/base_afe_units.py @@ -787,6 +787,32 @@ def _execute( class BaseAbsoluteMultiStateSimulationUnit(gufe.ProtocolUnit, AbsoluteUnitMixin): + @staticmethod + def _check_restart(output_settings: SettingsBaseModel, shared_path: pathlib.Path): + """ + Check if we are doing a restart. + + Parameters + ---------- + output_settings : SettingsBaseModel + The simulation output settings + shared_path : pathlib.Path + The shared directory where we should be looking for existing files. + + Notes + ----- + For now this just checks if the netcdf files are present in the + shared directory but in the future this may expand depending on + how warehouse works. 
+ """ + trajectory = shared_path / output_settings.output_filename + checkpoint = shared_path / output_settings.checkpoint_storage_filename + + if trajectory.is_file() and checkpoint.is_file(): + return True + + return False + @abc.abstractmethod def _get_components( self, @@ -1034,7 +1060,7 @@ def _get_reporter( time_per_iteration=simulation_settings.time_per_iteration, ) - reporter = multistate.MultiStateReporter( + return multistate.MultiStateReporter( storage=nc, analysis_particle_indices=selection_indices, checkpoint_interval=chk_intervals, @@ -1043,8 +1069,6 @@ def _get_reporter( velocity_interval=vel_interval, ) - return reporter - @staticmethod def _get_sampler( integrator: openmmtools.mcmc.LangevinDynamicsMove, From 3c97de5ec920995598a7557b497f585c0022e728 Mon Sep 17 00:00:00 2001 From: IAlibay Date: Thu, 22 Jan 2026 21:34:25 +0000 Subject: [PATCH 2/3] code changes to support resuming in AFE protocols --- .../protocols/openmm_afe/base_afe_units.py | 180 ++++++++++++------ 1 file changed, 119 insertions(+), 61 deletions(-) diff --git a/src/openfe/protocols/openmm_afe/base_afe_units.py b/src/openfe/protocols/openmm_afe/base_afe_units.py index 20ea73fe3..2ad67a4d2 100644 --- a/src/openfe/protocols/openmm_afe/base_afe_units.py +++ b/src/openfe/protocols/openmm_afe/base_afe_units.py @@ -1078,6 +1078,7 @@ def _get_sampler( compound_states: list[ThermodynamicState], sampler_states: list[SamplerState], platform: openmm.Platform, + restart: bool, ) -> multistate.MultiStateSampler: """ Get a sampler based on the equilibrium sampling method requested. @@ -1098,51 +1099,93 @@ def _get_sampler( A list of sampler states. platform : openmm.Platform The compute platform to use. + restart : bool + ``True`` if we are doing a simulation restart. Returns ------- sampler : multistate.MultistateSampler A sampler configured for the chosen sampling method. 
""" + _SAMPLERS = { + "repex": multistate.ReplicaExchangeSampler, + "sams": multistate.SAMSSampler, + "independent": multistate.MultiStateSampler, + } + + sampler_method = simulation_settings.sampler_method.lower() + + # Get the real time analysis values to use rta_its, rta_min_its = settings_validation.convert_real_time_analysis_iterations( simulation_settings=simulation_settings, ) - et_target_err = settings_validation.convert_target_error_from_kcal_per_mole_to_kT( - thermodynamic_settings.temperature, - simulation_settings.early_termination_target_error, - ) - # Select the right sampler - # Note: doesn't need else, settings already validates choices - if simulation_settings.sampler_method.lower() == "repex": - sampler = multistate.ReplicaExchangeSampler( - mcmc_moves=integrator, - online_analysis_interval=rta_its, - online_analysis_target_error=et_target_err, - online_analysis_minimum_iterations=rta_min_its, - ) - elif simulation_settings.sampler_method.lower() == "sams": - sampler = multistate.SAMSSampler( - mcmc_moves=integrator, - online_analysis_interval=rta_its, - online_analysis_minimum_iterations=rta_min_its, - flatness_criteria=simulation_settings.sams_flatness_criteria, - gamma0=simulation_settings.sams_gamma0, - ) - elif simulation_settings.sampler_method.lower() == "independent": - sampler = multistate.MultiStateSampler( - mcmc_moves=integrator, - online_analysis_interval=rta_its, - online_analysis_target_error=et_target_err, - online_analysis_minimum_iterations=rta_min_its, + # Get the number of production iterations to run for + steps_per_iteration = integrator.n_steps + timestep = from_openmm(integrator.timestep) + number_of_iterations = int( + settings_validation.get_simsteps( + sim_length=simulation_settings.production_length, + timestep=timestep, + mc_steps=steps_per_iteration, ) + / steps_per_iteration + ) - sampler.create( - thermodynamic_states=compound_states, - sampler_states=sampler_states, - storage=reporter, + # convert 
early_termination_target_error from kcal/mol to kT
+        early_termination_target_error = (
+            settings_validation.convert_target_error_from_kcal_per_mole_to_kT(
+                thermodynamic_settings.temperature,
+                simulation_settings.early_termination_target_error,
+            )
+        )
+
+        sampler_kwargs = {
+            "mcmc_moves": integrator,
+            "online_analysis_interval": rta_its,
+            "online_analysis_target_error": early_termination_target_error,
+            "online_analysis_minimum_iterations": rta_min_its,
+            "number_of_iterations": number_of_iterations,
+        }
+
+        if sampler_method == "sams":
+            sampler_kwargs |= {
+                "flatness_criteria": simulation_settings.sams_flatness_criteria,
+                "gamma0": simulation_settings.sams_gamma0,
+            }
+
+        if sampler_method == "repex":
+            sampler_kwargs |= {
+                "replica_mixing_scheme": "swap-all",
+            }
+
+        if restart:
+            sampler = _SAMPLERS[sampler_method].from_storage(reporter)
+
+            # Sanity-check the checkpointed sampler against the protocol settings
+            sampler_system = sampler._thermodynamic_states[0].get_system(remove_thermostat=True)
+            system = compound_states[0].get_system(remove_thermostat=True)
+
+            if (
+                (simulation_settings.n_replicas != sampler.n_states or sampler.n_states != sampler.n_replicas)
+                or (system.getNumForces() != sampler_system.getNumForces())
+                or (system.getNumParticles() != sampler_system.getNumParticles())
+                or (system.getNumConstraints() != sampler_system.getNumConstraints())
+                or (sampler.mcmc_moves[0].n_steps != steps_per_iteration)
+                or (sampler.mcmc_moves[0].timestep != integrator.timestep)
+            ):
+                errmsg = "System in checkpoint does not match protocol system, cannot resume"
+                raise ValueError(errmsg)
+        else:
+            sampler = _SAMPLERS[sampler_method](**sampler_kwargs)
+
+            sampler.create(
+                thermodynamic_states=compound_states,
+                sampler_states=sampler_states,
+                storage=reporter,
+            )
+
+        # Get and set the context caches
         sampler.energy_context_cache = openmmtools.cache.ContextCache(
             capacity=None,
             time_to_live=None,
@@ -1196,22 +1239,27 @@ def _run_simulation(
         )
 
         if not dry:  # pragma: no-cover
-            # minimize
-            if self.verbose:
- 
self.logger.info("minimizing systems")
+            # No production steps have been taken, so start from scratch
+            if sampler._iteration == 0:
+                # minimize
+                if self.verbose:
+                    self.logger.info("minimizing systems")
 
-            sampler.minimize(max_iterations=settings["simulation_settings"].minimization_steps)
+                sampler.minimize(max_iterations=settings["simulation_settings"].minimization_steps)
 
-            # equilibrate
-            if self.verbose:
-                self.logger.info("equilibrating systems")
+                # equilibrate
+                if self.verbose:
+                    self.logger.info("equilibrating systems")
 
-            sampler.equilibrate(int(equil_steps / mc_steps))
+                sampler.equilibrate(int(equil_steps / mc_steps))
 
-            # production
+            # At this point we are ready for production
             if self.verbose:
                 self.logger.info("running production phase")
-            sampler.extend(int(prod_steps / mc_steps))
+
+            # We use `run` so that we're limited by the number of iterations
+            # we passed when we built the sampler.
+            sampler.run(n_iterations=int(prod_steps / mc_steps) - sampler._iteration)
 
             if self.verbose:
                 self.logger.info("production phase complete")
@@ -1281,6 +1329,12 @@ def run(
         # Get the settings
         settings = self._get_settings()
 
+        # Check for a restart
+        self.restart = self._check_restart(
+            output_settings=settings["output_settings"],
+            shared_path=self.shared_basepath,
+        )
+
         # Get the components
         alchem_comps, solv_comp, prot_comp, small_mols = self._get_components()
 
@@ -1323,7 +1377,7 @@ def run(
             output_settings=settings["output_settings"],
         )
 
-        # Get sampler
+        # Get the sampler
         sampler = self._get_sampler(
             integrator=integrator,
             reporter=reporter,
@@ -1332,9 +1386,10 @@ def run(
             compound_states=cmp_states,
             sampler_states=sampler_states,
             platform=platform,
+            restart=self.restart,
         )
 
-        # Run simulation
+        # Run the simulation
         self._run_simulation(
             sampler=sampler,
             reporter=reporter,
@@ -1343,24 +1398,27 @@ def run(
         )
 
         finally:
-            # close reporter when you're done to prevent file handle clashes
-            reporter.close()
-
-            # clear GPU context
-            # Note: use cache.empty() when 
openmmtools #690 is resolved - for context in list(sampler.energy_context_cache._lru._data.keys()): - del sampler.energy_context_cache._lru._data[context] - for context in list(sampler.sampler_context_cache._lru._data.keys()): - del sampler.sampler_context_cache._lru._data[context] - # cautiously clear out the global context cache too - for context in list(openmmtools.cache.global_context_cache._lru._data.keys()): - del openmmtools.cache.global_context_cache._lru._data[context] - - del sampler.sampler_context_cache, sampler.energy_context_cache - - # Keep these around in a dry run so we can inspect things - if not dry: - del integrator, sampler + # Order is reporter, sampler, and then integrator + try: + reporter.close() # close to prevent file handle clashes + + # clear GPU context + # Note: use cache.empty() when openmmtools #690 is resolved + for context in list(sampler.energy_context_cache._lru._data.keys()): + del sampler.energy_context_cache._lru._data[context] + for context in list(sampler.sampler_context_cache._lru._data.keys()): + del sampler.sampler_context_cache._lru._data[context] + # cautiously clear out the global context cache too + for context in list(openmmtools.cache.global_context_cache._lru._data.keys()): + del openmmtools.cache.global_context_cache._lru._data[context] + + del sampler.sampler_context_cache, sampler.energy_context_cache + + # Keep these around in a dry run so we can inspect things + if not dry: + del integrator, sampler + except UnboundLocalError: + pass if not dry: nc = self.shared_basepath / settings["output_settings"].output_filename From 104f850bd3c76ed51c4a50ea420b91e78cc92344 Mon Sep 17 00:00:00 2001 From: IAlibay Date: Sun, 25 Jan 2026 20:02:16 +0000 Subject: [PATCH 3/3] some edits --- .../protocols/openmm_ahfe/test_ahfe_resume.py | 197 ++++++++++++++++++ 1 file changed, 197 insertions(+) create mode 100644 src/openfe/tests/protocols/openmm_ahfe/test_ahfe_resume.py diff --git 
a/src/openfe/tests/protocols/openmm_ahfe/test_ahfe_resume.py b/src/openfe/tests/protocols/openmm_ahfe/test_ahfe_resume.py
new file mode 100644
index 000000000..a78261d0d
--- /dev/null
+++ b/src/openfe/tests/protocols/openmm_ahfe/test_ahfe_resume.py
@@ -0,0 +1,198 @@
+# This code is part of OpenFE and is licensed under the MIT license.
+# For details, see https://github.com/OpenFreeEnergy/openfe
+
+import os
+import pathlib
+import pooch
+
+import pytest
+from openff.units import unit as offunit
+
+import openfe
+from openfe.protocols import openmm_afe
+
+from ...conftest import HAS_INTERNET
+from utils import _get_units
+
+
+POOCH_CACHE = pooch.os_cache("openfe")
+zenodo_resume_data = pooch.create(
+    path=POOCH_CACHE,
+    base_url="doi:10.5281/zenodo.18331259",
+    registry={"multistate_checkpoints.zip": "md5:2cf8aa417ac8311aca1551d4abf3b3ed"},
+)
+
+@pytest.fixture(scope="module")
+def vac_trajectory_path():
+    zenodo_resume_data.fetch("multistate_checkpoints.zip", processor=pooch.Unzip())
+    topdir = "multistate_checkpoints.zip.unzip/multistate_checkpoints"
+    subdir = "ahfes"
+    filename = "vacuum.nc"
+    return pathlib.Path(pooch.os_cache("openfe") / f"{topdir}/{subdir}/{filename}")
+
+
+@pytest.fixture(scope="module")
+def vac_checkpoint_path():
+    zenodo_resume_data.fetch("multistate_checkpoints.zip", processor=pooch.Unzip())
+    topdir = "multistate_checkpoints.zip.unzip/multistate_checkpoints"
+    subdir = "ahfes"
+    filename = "vacuum_checkpoint.chk"
+    return pathlib.Path(pooch.os_cache("openfe") / f"{topdir}/{subdir}/{filename}")
+
+
+@pytest.fixture(scope="module")
+def sol_trajectory_path():
+    zenodo_resume_data.fetch("multistate_checkpoints.zip", processor=pooch.Unzip())
+    topdir = "multistate_checkpoints.zip.unzip/multistate_checkpoints"
+    subdir = "ahfes"
+    filename = "solvent.nc"
+    return pathlib.Path(pooch.os_cache("openfe") / f"{topdir}/{subdir}/{filename}")
+
+
+@pytest.fixture(scope="module")
+def sol_checkpoint_path():
+ 
zenodo_resume_data.fetch("multistate_checkpoints.zip", processor=pooch.Unzip())
+    topdir = "multistate_checkpoints.zip.unzip/multistate_checkpoints"
+    subdir = "ahfes"
+    filename = "solvent_checkpoint.chk"
+    return pathlib.Path(pooch.os_cache("openfe") / f"{topdir}/{subdir}/{filename}")
+
+
+@pytest.fixture()
+def protocol_settings():
+    settings = openmm_afe.AbsoluteSolvationProtocol.default_settings()
+    settings.protocol_repeats = 1
+    settings.solvent_output_settings.output_indices = "resname UNK"
+    settings.solvation_settings.solvent_padding = None
+    settings.solvation_settings.number_of_solvent_molecules = 750
+    settings.solvation_settings.box_shape = "dodecahedron"
+    settings.vacuum_simulation_settings.equilibration_length = 100 * offunit.picosecond
+    settings.vacuum_simulation_settings.production_length = 200 * offunit.picosecond
+    settings.solvent_simulation_settings.equilibration_length = 100 * offunit.picosecond
+    settings.solvent_simulation_settings.production_length = 200 * offunit.picosecond
+    settings.vacuum_engine_settings.compute_platform = "CUDA"
+    settings.solvent_engine_settings.compute_platform = "CUDA"
+    settings.vacuum_simulation_settings.time_per_iteration = 2.5 * offunit.picosecond
+    settings.solvent_simulation_settings.time_per_iteration = 2.5 * offunit.picosecond
+    settings.vacuum_output_settings.checkpoint_interval = 100 * offunit.picosecond
+    settings.solvent_output_settings.checkpoint_interval = 100 * offunit.picosecond
+    return settings
+
+
+@pytest.mark.skipif(
+    not os.path.exists(POOCH_CACHE) and not HAS_INTERNET,
+    reason="Internet unavailable and test data is not cached locally",
+)
+def test_solvent_check_restart(protocol_settings, sol_trajectory_path):
+    assert openmm_afe.ABFESolventSimUnit._check_restart(
+        output_settings=protocol_settings.solvent_output_settings,
+        shared_path=sol_trajectory_path.parent,
+    )
+
+    assert not openmm_afe.ABFESolventSimUnit._check_restart(
+        output_settings=protocol_settings.solvent_output_settings,
+ shared_path=pathlib.Path("."), + ) + + +@pytest.mark.skipif( + not os.path.exists(POOCH_CACHE) and not HAS_INTERNET, + reason="Internet unavailable and test data is not cached locally", +) +def test_vacuum_check_restart(protocol_settings, vac_trajectory_path): + assert openmm_afe.ABFEVacuumSimUnit._check_restart( + output_settings=protocol_settings.vacuum_output_settings, + shared_path=vac_trajectory_path.parent, + ) + + assert not openmm_afe.ABFEVacuumSimUnit._check_restart( + output_settings=protocol_settings.vacuum_output_settings, + shared_path=pathlib.Path("."), + ) + + + +class TestCheckpointResuming: + @pytest.fixture() + def protocol_dag( + self, protocol_settings, benzene_modifications, + ): + stateA = openfe.ChemicalSystem( + { + "benzene": benzene_modifications["benzene"], + "solvent": openfe.SolventComponent(), + } + ) + + stateB = openfe.ChemicalSystem({"solvent": openfe.SolventComponent()}) + + protocol = openmm_afe.AbsoluteSolvationProtocol(settings=protocol_settings) + + # Create DAG from protocol, get the vacuum and solvent units + # and eventually dry run the first solvent unit + return protocol.create( + stateA=stateA, + stateB=stateB, + mapping=None, + ) + + def test_resume(self, protocol_dag, tmpdir): + """ + Attempt to resume a simulation unit with pre-existing checkpoint & + trajectory files. 
+ """ + cwd = pathlib.Path("resume_files") + r = openfe.execute_DAG(protocol_dag, shared_basedir=cwd, scratch_basedir=cwd, keep_shared=True) + + + + + +# @pytest.mark.integration # takes too long to be a slow test ~ 4 mins locally +# def test_openmm_run_engine( +# platform, +# get_available_openmm_platforms, +# benzene_modifications, +# tmpdir, +# ): +# cwd = pathlib.Path(str(tmpdir)) +# r = execute_DAG(dag, shared_basedir=cwd, scratch_basedir=cwd, keep_shared=True) +# +# assert r.ok() +# +# # Check outputs of solvent & vacuum results +# for phase in ["solvent", "vacuum"]: +# purs = [pur for pur in r.protocol_unit_results if pur.outputs["simtype"] == phase] +# +# # get the path to the simulation unit shared dict +# for pur in purs: +# if "Simulation" in pur.name: +# sim_shared = tmpdir / f"shared_{pur.source_key}_attempt_0" +# assert sim_shared.exists() +# assert pathlib.Path(sim_shared).is_dir() +# +# # check the analysis outputs +# for pur in purs: +# if "Analysis" not in pur.name: +# continue +# +# unit_shared = tmpdir / f"shared_{pur.source_key}_attempt_0" +# assert unit_shared.exists() +# assert pathlib.Path(unit_shared).is_dir() +# +# # Does the checkpoint file exist? +# checkpoint = pur.outputs["checkpoint"] +# assert checkpoint == sim_shared / f"{pur.outputs['simtype']}_checkpoint.nc" +# assert checkpoint.exists() +# +# # Does the trajectory file exist? +# nc = pur.outputs["trajectory"] +# assert nc == sim_shared / f"{pur.outputs['simtype']}.nc" +# assert nc.exists() +# +# # Test results methods that need files present +# results = protocol.gather([r]) +# states = results.get_replica_states() +# assert len(states.items()) == 2 +# assert len(states["solvent"]) == 1 +# assert states["solvent"][0].shape[1] == 20