Skip to content
36 changes: 30 additions & 6 deletions feflow/protocols/nonequilibrium_cycling.py
Original file line number Diff line number Diff line change
Expand Up @@ -309,8 +309,32 @@ def _execute(self, ctx, *, protocol, state_a, state_b, mapping, **inputs):
system = hybrid_factory.hybrid_system
positions = hybrid_factory.hybrid_positions

# Set up integrator
temperature = to_openmm(thermodynamic_settings.temperature)
# serialize system
system_outfile = ctx.shared / "system.xml.bz2"
serialize(system, system_outfile)

# Serialize positions
positions_outfile = ctx.shared / "positions.npy"
np.save(positions_outfile, positions)

# Serialize HTF
htf_outfile = ctx.shared / "hybrid_topology_factory.pickle"
# Serialize HTF, system, state and integrator
with open(htf_outfile, "wb") as htf_file:
pickle.dump(hybrid_factory, htf_file)

return {
"system": system_outfile,
"positions": positions_outfile,
"phase": phase,
"initial_atom_indices": hybrid_factory.initial_atom_indices,
"final_atom_indices": hybrid_factory.final_atom_indices,
"topology_path": htf_outfile,
}


class IntegratorSetupUnit(ProtocolUnit):
def _execute(ctx: Context, setup, **inputs) -> dict[str, Any]:
integrator_settings = settings.integrator_settings
integrator = PeriodicNonequilibriumIntegrator(
alchemical_functions=settings.lambda_functions,
Expand All @@ -321,11 +345,14 @@ def _execute(self, ctx, *, protocol, state_a, state_b, mapping, **inputs):
temperature=temperature,
)

# TODO: Make sure we load the outputs from setup unit to meet the needs of this unit

# Set up context
platform = get_openmm_platform(settings.engine_settings.compute_platform)
context = openmm.Context(system, integrator, platform)
context.setPeriodicBoxVectors(*system.getDefaultPeriodicBoxVectors())
context.setPositions(positions)
serialize(integrator_, integrator_outfile)

try:
# Minimize
Expand Down Expand Up @@ -363,15 +390,12 @@ def _execute(self, ctx, *, protocol, state_a, state_b, mapping, **inputs):

serialize(system_, system_outfile)
serialize(state_, state_outfile)
serialize(integrator_, integrator_outfile)

finally:
# Explicit cleanup for GPU resources
del context, integrator
del context

return {
"system": system_outfile,
"state": state_outfile,
"integrator": integrator_outfile,
"initial_atom_indices": hybrid_factory.initial_atom_indices,
"final_atom_indices": hybrid_factory.final_atom_indices,
Expand Down
169 changes: 169 additions & 0 deletions feflow/protocols/nonequilibrium_switching.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,169 @@
import mdtraj


class BaseSwitchingUnit(ProtocolUnit):
"""
Monolithic unit for the cycle part of the simulation.
It runs a number of NEq cycles from the outputs of a setup unit and stores the work computed in
numpy-formatted files, to be analyzed by a result unit.
"""

@staticmethod
def extract_positions(context, initial_atom_indices, final_atom_indices):
"""
Extract positions from initial and final systems based from the hybrid topology.

Parameters
----------
context: openmm.Context
Current simulation context where from extract positions.
hybrid_topology_factory: HybridTopologyFactory
Hybrid topology factory where to extract positions and mapping information

Returns
-------

Notes
-----
It achieves this by taking the positions and indices from the initial and final states of
the transformation, and computing the overlap of these with the indices of the complete
hybrid topology, filtered by some mdtraj selection expression.

1. Get positions from context
2. Get topology from HTF (already mdtraj topology)
3. Merge that information into mdtraj.Trajectory
4. Filter positions for initial/final according to selection string
"""
import numpy as np

# Get positions from current openmm context
positions = context.getState(getPositions=True).getPositions(asNumpy=True)

# Get indices for initial and final topologies in hybrid topology
initial_indices = np.asarray(initial_atom_indices)
final_indices = np.asarray(final_atom_indices)

initial_positions = positions[initial_indices, :]
final_positions = positions[final_indices, :]

return initial_positions, final_positions

def _execute(self, ctx, *, protocol, md_unit, index, **inputs):
"""
Execute the simulation part of the Nonequilibrium switching protocol using GUFE objects.

Parameters
----------
ctx : gufe.protocols.protocolunit.Context
The gufe context for the unit.
protocol : gufe.protocols.Protocol
The Protocol used to create this Unit. Contains key information
such as the settings.
md_unit : gufe.protocols.ProtocolUnit
The SetupUnit
index: int
TODO: Index for the snapshot to use as input

Returns
-------
dict : dict[str, str]
Dictionary with paths to work arrays, both forward and reverse, and trajectory coordinates for systems
A and B.
"""
import openmm
from openmmtools.integrators import PeriodicNonequilibriumIntegrator

# Setting up logging to file in shared filesystem
file_logger = logging.getLogger("neq-cycling")
output_log_path = ctx.shared / "feflow-neq-cycling.log"
file_handler = logging.FileHandler(output_log_path, mode="w")
file_handler.setLevel(logging.DEBUG) # TODO: Set to INFO in production
log_formatter = logging.Formatter(
fmt="%(asctime)s %(levelname)-8s %(message)s", datefmt="%Y-%m-%d %H:%M:%S"
)
file_handler.setFormatter(log_formatter)
file_logger.addHandler(file_handler)

system = deserialize(md_unit.inputs["setup"].outputs["system"])
state = deserialize(md_unit.inputs["setup"].outputs["state"])
integrator = deserialize(md_unit.inputs["setup"].outputs["integrator"])

PeriodicNonequilibriumIntegrator.restore_interface(integrator)

# Get atom indices for either end of the hybrid topology
initial_atom_indices = setup.outputs["initial_atom_indices"]
final_atom_indices = setup.outputs["final_atom_indices"]

# Extract settings from the Protocol
settings = protocol.settings

# Load positions from snapshots
xtc_file = md_unit.outputs["production_trajectory"]
md_traj_ob = mdtraj.load_frame(xtc_file, index=index)
input_positions = md_traj_ob.openmm_positions(0)
# Set up context
platform = get_openmm_platform(settings.engine_settings.compute_platform)
context = openmm.Context(system, integrator, platform)
context.setState(state)
# TODO: This is kinda ugly, is there a better way to set positions?
context.setPositions(input_positions)

# Setting velocities to temperatures
thermodynamic_settings = settings.thermo_settings
temperature = to_openmm(thermodynamic_settings.temperature)
context.setVelocitiesToTemperature(temperature)

# Extract settings used below
neq_steps = settings.integrator_settings.nonequilibrium_steps
traj_save_frequency = settings.traj_save_frequency
work_save_frequency = (
settings.work_save_frequency
) # Note: this is divisor of traj save freq.
selection_expression = settings.atom_selection_expression

try:
# Coarse number of steps -- each coarse consists of work_save_frequency steps
coarse_neq_steps = int(
neq_steps / work_save_frequency
) # Note: neq_steps is multiple of work save steps

# TODO: Also get the GPU information (plain try-except with nvidia-smi)


integrator.step(NSTEPS)




# Equilibrium (lambda = 0)
# start timer
start_time = time.perf_counter()
# Run neq
# Forward (0 -> 1)
# Initialize works with current value
forward_works = []
for fwd_step in range(coarse_neq_steps):
integrator.step(work_save_frequency)
forward_works.append(integrator.get_protocol_work(dimensionless=True))
if fwd_step % traj_save_frequency == 0:
initial_positions, final_positions = self.extract_positions(
context, initial_atom_indices, final_atom_indices
)
forward_neq_initial.append(initial_positions)
forward_neq_final.append(final_positions)
# Make sure trajectories are stored at the end of the neq loop
initial_positions, final_positions = self.extract_positions(
context, initial_atom_indices, final_atom_indices
)
forward_neq_initial.append(initial_positions)
forward_neq_final.append(final_positions)

neq_forward_time = time.perf_counter()
neq_forward_walltime = datetime.timedelta(
seconds=neq_forward_time - eq_forward_time
)
file_logger.info(
f"replicate_{self.name} Forward nonequilibrium time (lambda 0 -> 1): {neq_forward_walltime}"
)

# TODO: We should return the work in one direction
60 changes: 37 additions & 23 deletions feflow/settings/integrators.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,10 @@

from typing import Annotated, TypeAlias

from pydantic.v1 import validator

from openff.units import unit
from gufe.settings import SettingsBaseModel
from gufe.settings.typing import GufeQuantity, specify_quantity_units
from pydantic import field_validator, ConfigDict

FemtosecondQuantity: TypeAlias = Annotated[
GufeQuantity, specify_quantity_units("femtoseconds")
Expand All @@ -22,15 +21,45 @@
]


class PeriodicNonequilibriumIntegratorSettings(SettingsBaseModel):
"""Settings for the PeriodicNonequilibriumIntegrator"""
class BaseNonequilibriumIntegrator(SettingsBaseModel):
"""Base class for nonequilibrium integrator settings"""

class Config:
arbitrary_types_allowed = True
model_config = ConfigDict(arbitrary_types_allowed=True)

timestep: FemtosecondQuantity = 4 * unit.femtoseconds
"""Size of the simulation timestep. Default 4 fs."""
splitting: str = "V R H O R V"

# TODO: This validator is used in other settings, better create a new Type
@field_validator("timestep")
@classmethod
def must_be_positive(cls, v):
if v <= 0:
errmsg = f"timestep must be positive, received {v}."
raise ValueError(errmsg)
return v

# TODO: This validator is used in other settings, better create a new Type
@field_validator("timestep")
@classmethod
def is_time(cls, v):
# these are time units, not simulation steps
if not v.is_compatible_with(unit.picosecond):
raise ValueError("timestep must be in time units " "(i.e. picoseconds)")
return v


class AlchemicalNonequilibriumIntegratorSettings(BaseNonequilibriumIntegrator):
"""Settings for the AlchemicalNonequilibriumIntegrator used for switching"""

...


class PeriodicNonequilibriumIntegratorSettings(BaseNonequilibriumIntegrator):
"""Settings for the PeriodicNonequilibriumIntegrator"""

model_config = ConfigDict(arbitrary_types_allowed=True)

"""Operator splitting"""
equilibrium_steps: int = 12500
"""Number of steps for the equilibrium parts of the cycle. Default 12500"""
Expand All @@ -48,23 +77,8 @@ class Config:
"""

# TODO: This validator is used in other settings, better create a new Type
@validator("timestep")
def must_be_positive(cls, v):
if v <= 0:
errmsg = f"timestep must be positive, received {v}."
raise ValueError(errmsg)
return v

# TODO: This validator is used in other settings, better create a new Type
@validator("timestep")
def is_time(cls, v):
# these are time units, not simulation steps
if not v.is_compatible_with(unit.picosecond):
raise ValueError("timestep must be in time units " "(i.e. picoseconds)")
return v

# TODO: This validator is used in other settings, better create a new Type
@validator("equilibrium_steps", "nonequilibrium_steps")
@field_validator("equilibrium_steps", "nonequilibrium_steps")
@classmethod
def must_be_positive_or_zero(cls, v):
if v < 0:
errmsg = (
Expand Down
Loading