Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
210 changes: 146 additions & 64 deletions src/openfe/protocols/openmm_afe/base_afe_units.py
Original file line number Diff line number Diff line change
Expand Up @@ -787,6 +787,32 @@ def _execute(


class BaseAbsoluteMultiStateSimulationUnit(gufe.ProtocolUnit, AbsoluteUnitMixin):
@staticmethod
def _check_restart(output_settings: SettingsBaseModel, shared_path: pathlib.Path):
    """
    Decide whether this unit should resume from existing simulation output.

    Parameters
    ----------
    output_settings : SettingsBaseModel
      The simulation output settings; its ``output_filename`` and
      ``checkpoint_storage_filename`` attributes name the files to look for.
    shared_path : pathlib.Path
      The shared directory that is searched for those files.

    Returns
    -------
    bool
      ``True`` only when both the trajectory file and the checkpoint file
      exist as regular files inside ``shared_path``; ``False`` otherwise.

    Notes
    -----
    Only file presence is tested; nothing is opened or validated here.
    The criteria may be extended later depending on how storage
    (warehouse) ends up working.
    """
    # Build both candidate paths up front, then require every one of them
    # to be an existing regular file.
    expected_files = [
        shared_path / filename
        for filename in (
            output_settings.output_filename,
            output_settings.checkpoint_storage_filename,
        )
    ]
    return all(candidate.is_file() for candidate in expected_files)

@abc.abstractmethod
def _get_components(
self,
Expand Down Expand Up @@ -1034,7 +1060,7 @@ def _get_reporter(
time_per_iteration=simulation_settings.time_per_iteration,
)

reporter = multistate.MultiStateReporter(
return multistate.MultiStateReporter(
storage=nc,
analysis_particle_indices=selection_indices,
checkpoint_interval=chk_intervals,
Expand All @@ -1043,8 +1069,6 @@ def _get_reporter(
velocity_interval=vel_interval,
)

return reporter

@staticmethod
def _get_sampler(
integrator: openmmtools.mcmc.LangevinDynamicsMove,
Expand All @@ -1054,6 +1078,7 @@ def _get_sampler(
compound_states: list[ThermodynamicState],
sampler_states: list[SamplerState],
platform: openmm.Platform,
restart: bool,
) -> multistate.MultiStateSampler:
"""
Get a sampler based on the equilibrium sampling method requested.
Expand All @@ -1074,51 +1099,93 @@ def _get_sampler(
A list of sampler states.
platform : openmm.Platform
The compute platform to use.
restart : bool
``True`` if we are doing a simulation restart.

Returns
-------
sampler : multistate.MultiStateSampler
A sampler configured for the chosen sampling method.
"""
_SAMPLERS = {
"repex": multistate.ReplicaExchangeSampler,
"sams": multistate.SAMSSampler,
"independent": multistate.MultiStateSampler,
}

sampler_method = simulation_settings.sampler_method.lower()

# Get the real time analysis values to use
rta_its, rta_min_its = settings_validation.convert_real_time_analysis_iterations(
simulation_settings=simulation_settings,
)
et_target_err = settings_validation.convert_target_error_from_kcal_per_mole_to_kT(
thermodynamic_settings.temperature,
simulation_settings.early_termination_target_error,
)

# Select the right sampler
# Note: doesn't need else, settings already validates choices
if simulation_settings.sampler_method.lower() == "repex":
sampler = multistate.ReplicaExchangeSampler(
mcmc_moves=integrator,
online_analysis_interval=rta_its,
online_analysis_target_error=et_target_err,
online_analysis_minimum_iterations=rta_min_its,
)
elif simulation_settings.sampler_method.lower() == "sams":
sampler = multistate.SAMSSampler(
mcmc_moves=integrator,
online_analysis_interval=rta_its,
online_analysis_minimum_iterations=rta_min_its,
flatness_criteria=simulation_settings.sams_flatness_criteria,
gamma0=simulation_settings.sams_gamma0,
)
elif simulation_settings.sampler_method.lower() == "independent":
sampler = multistate.MultiStateSampler(
mcmc_moves=integrator,
online_analysis_interval=rta_its,
online_analysis_target_error=et_target_err,
online_analysis_minimum_iterations=rta_min_its,
# Get the number of production iterations to run for
steps_per_iteration = integrator.n_steps
timestep = from_openmm(integrator.timestep)
number_of_iterations = int(
settings_validation.get_simsteps(
sim_length=simulation_settings.production_length,
timestep=timestep,
mc_steps=steps_per_iteration,
)
/ steps_per_iteration
)

sampler.create(
thermodynamic_states=compound_states,
sampler_states=sampler_states,
storage=reporter,
# convert early_termination_target_error from kcal/mol to kT
early_termination_target_error = (
settings_validation.convert_target_error_from_kcal_per_mole_to_kT(
thermodynamic_settings.temperature,
simulation_settings.early_termination_target_error,
)
)

sampler_kwargs = {
"mcmc_moves": integrator,
"online_analysis_interval": rta_its,
"online_analysis_target_error": early_termination_target_error,
"online_analysis_minimum_iterations": rta_min_its,
"number_of_iterations": number_of_iterations,
}

if sampler_method == "sams":
sampler_kwargs |= {
"flatness_criteria": simulation_settings.sams_flatness_criteria,
"gamma0": simulation_settings.sams_gamma0,
}

if sampler_method == "repex":
sampler_kwargs |= {
"replica_mixing_scheme": "swap-all",
}

if restart:
sampler = _SAMPLERS[sampler_method].from_storage(reporter)

# Add some tests here
sampler_system = sampler._thermodynamic_states[0].get_system(remove_thermostat=True)
system = compound_states[0].get_system(rermove_thermostat=True)

if (
(simulation_settings.n_replicas != sampler.n_states != sampler.n_replicas)
or (system.getNumForces() != sampler_system.getNumForces())
or (system.getNumParticles() != sampler_system.getNumParticles())
or (system.getNumConstraints() != sampler_system.getNumConstraints())
or (sampler.mcmc_moves[0].n_steps != steps_per_iteration)
or (sampler.mcmc_moves[0].timestep != integrator.timestep)
):
errmsg = "System in checkpoint does not match protocol system, cannot resume"
raise ValueError(errmsg)
else:
sampler = _SAMPLERS[sampler_method](**sampler_kwargs)

sampler.create(
thermodynamic_states=compound_states,
sampler_states=sampler_states,
storage=reporter,
)

# Get and set the context caches
sampler.energy_context_cache = openmmtools.cache.ContextCache(
capacity=None,
time_to_live=None,
Expand Down Expand Up @@ -1172,22 +1239,27 @@ def _run_simulation(
)

if not dry: # pragma: no-cover
# minimize
if self.verbose:
self.logger.info("minimizing systems")
# No production steps have been taken, so start from scratch
if sampler._iteration == 0:
# minimize
if self.verbose:
self.logger.info("minimizing systems")

sampler.minimize(max_iterations=settings["simulation_settings"].minimization_steps)
sampler.minimize(max_iterations=settings["simulation_settings"].minimization_steps)

# equilibrate
if self.verbose:
self.logger.info("equilibrating systems")
# equilibrate
if self.verbose:
self.logger.info("equilibrating systems")

sampler.equilibrate(int(equil_steps / mc_steps))
sampler.equilibrate(int(equil_steps / mc_steps))

# production
# At this point we are ready for production
if self.verbose:
self.logger.info("running production phase")
sampler.extend(int(prod_steps / mc_steps))

# We use `run` so that we're limited by the number of iterations
# we passed when we built the sampler.
sampler.run(n_iterations=int(prod_steps / mc_steps) - sampler._iteration)

if self.verbose:
self.logger.info("production phase complete")
Expand Down Expand Up @@ -1257,6 +1329,12 @@ def run(
# Get the settings
settings = self._get_settings()

# Check for a restart
self.restart = self._check_restart(
output_settings=settings["output_settings"],
shared_path=self.shared_basepath,
)

# Get the components
alchem_comps, solv_comp, prot_comp, small_mols = self._get_components()

Expand Down Expand Up @@ -1299,7 +1377,7 @@ def run(
output_settings=settings["output_settings"],
)

# Get sampler
# Get the sampler
sampler = self._get_sampler(
integrator=integrator,
reporter=reporter,
Expand All @@ -1308,9 +1386,10 @@ def run(
compound_states=cmp_states,
sampler_states=sampler_states,
platform=platform,
restart=self.restart,
)

# Run simulation
# Run the simulation
self._run_simulation(
sampler=sampler,
reporter=reporter,
Expand All @@ -1319,24 +1398,27 @@ def run(
)

finally:
# close reporter when you're done to prevent file handle clashes
reporter.close()

# clear GPU context
# Note: use cache.empty() when openmmtools #690 is resolved
for context in list(sampler.energy_context_cache._lru._data.keys()):
del sampler.energy_context_cache._lru._data[context]
for context in list(sampler.sampler_context_cache._lru._data.keys()):
del sampler.sampler_context_cache._lru._data[context]
# cautiously clear out the global context cache too
for context in list(openmmtools.cache.global_context_cache._lru._data.keys()):
del openmmtools.cache.global_context_cache._lru._data[context]

del sampler.sampler_context_cache, sampler.energy_context_cache

# Keep these around in a dry run so we can inspect things
if not dry:
del integrator, sampler
# Order is reporter, sampler, and then integrator
try:
reporter.close() # close to prevent file handle clashes

# clear GPU context
# Note: use cache.empty() when openmmtools #690 is resolved
for context in list(sampler.energy_context_cache._lru._data.keys()):
del sampler.energy_context_cache._lru._data[context]
for context in list(sampler.sampler_context_cache._lru._data.keys()):
del sampler.sampler_context_cache._lru._data[context]
# cautiously clear out the global context cache too
for context in list(openmmtools.cache.global_context_cache._lru._data.keys()):
del openmmtools.cache.global_context_cache._lru._data[context]

del sampler.sampler_context_cache, sampler.energy_context_cache

# Keep these around in a dry run so we can inspect things
if not dry:
del integrator, sampler
except UnboundLocalError:
pass

if not dry:
nc = self.shared_basepath / settings["output_settings"].output_filename
Expand Down
Loading
Loading