Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
152 changes: 131 additions & 21 deletions openfe/protocols/openmm_afe/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,6 @@
from openmmtools.alchemy import (
AlchemicalRegion,
AbsoluteAlchemicalFactory,
AlchemicalState,
)
from typing import Optional
from openmm import app
Expand Down Expand Up @@ -83,6 +82,50 @@
logger = logging.getLogger(__name__)


class SingleRegionAlchemicalState(GlobalParameterState):
class _LambdaParameter(GlobalParameterState.GlobalParameter):
"""A global parameter in the interval [0, 1] with standard value 1."""

def __init__(self, parameter_name):
super().__init__(parameter_name, standard_value=1.0, validator=self.lambda_validator)

@staticmethod
def lambda_validator(self, instance, parameter_value):
if parameter_value is None:
return parameter_value
if not (0.0 <= parameter_value <= 1.0):
raise ValueError('{} must be between 0 and 1.'.format(self.parameter_name))
return float(parameter_value)

lambda_sterics_A = _LambdaParameter("lambda_sterics_A")
lambda_electrostatics_A = _LambdaParameter("lambda_electrostatics_A")
lambda_bonds_A = _LambdaParameter("lambda_bonds_A")
lambda_angles_A = _LambdaParameter("lambda_angles_A")
lambda_torsions_A = _LambdaParameter("lambda_torsions_A")


class TwoRegionAlchemicalState(SingleRegionAlchemicalState):
class _LambdaParameter(GlobalParameterState.GlobalParameter):
"""A global parameter in the interval [0, 1] with standard value 1."""

def __init__(self, parameter_name):
super().__init__(parameter_name, standard_value=1.0, validator=self.lambda_validator)

@staticmethod
def lambda_validator(self, instance, parameter_value):
if parameter_value is None:
return parameter_value
if not (0.0 <= parameter_value <= 1.0):
raise ValueError('{} must be between 0 and 1.'.format(self.parameter_name))
return float(parameter_value)

lambda_sterics_B = _LambdaParameter("lambda_sterics_B")
lambda_electrostatics_B = _LambdaParameter("lambda_electrostatics_B")
lambda_bonds_B = _LambdaParameter("lambda_bonds_B")
lambda_angles_B = _LambdaParameter("lambda_angles_B")
lambda_torsions_B = _LambdaParameter("lambda_torsions_B")


class BaseAbsoluteUnit(gufe.ProtocolUnit):
"""
Base class for ligand absolute free energy transformations.
Expand Down Expand Up @@ -598,18 +641,32 @@ def _get_lambda_schedule(

return lambdas

def _get_alchemical_ion(
self,
alchem_comps: dict[str, list[Component]],
comp_resids: dict[Component, npt.NDArray],
omm_topology: app.Topology,
positions: openmm.unit.Quantity,
settings: dict[str, SettingsBaseModel],
) -> int | None:
"""
Placeholder method to find alchemical ions if necessary.
"""
return None

def _add_restraints(
self,
system: openmm.System,
topology: GlobalParameterState,
topology: app.Topology,
positions: openmm.unit.Quantity,
alchem_comps: dict[str, list[Component]],
comp_resids: dict[Component, npt.NDArray],
settings: dict[str, SettingsBaseModel],
alchem_ion: int | None,
) -> tuple[
Optional[GlobalParameterState],
Optional[Quantity],
Optional[openmm.System],
openmm.System,
Optional[geometry.BaseRestraintGeometry],
]:
"""
Expand All @@ -623,7 +680,8 @@ def _get_alchemical_system(
system: openmm.System,
comp_resids: dict[Component, npt.NDArray],
alchem_comps: dict[str, list[Component]],
) -> tuple[AbsoluteAlchemicalFactory, openmm.System, list[int]]:
alchem_ion: int | None,
) -> tuple[AbsoluteAlchemicalFactory, openmm.System, dict[str, list[int]]]:
"""
Get an alchemically modified system and its associated factory

Expand All @@ -644,22 +702,36 @@ def _get_alchemical_system(
Factory for creating an alchemically modified system.
alchemical_system : openmm.System
Alchemically modified system
alchemical_indices : list[int]
alchemical_indices : dict[str, list[int]]
A list of atom indices for the alchemically modified
species in the system.

TODO
----
* Add support for all alchemical factory options
"""
alchemical_indices = self._get_alchemical_indices(topology, comp_resids, alchem_comps)
alchemical_indices = {
'A': self._get_alchemical_indices(topology, comp_resids, alchem_comps)
}

alchemical_region = AlchemicalRegion(
alchemical_atoms=alchemical_indices,
)
if alchem_ion is not None:
alchemical_indices['B'] = [alchem_ion]

alchemical_regions = []

for region in alchemical_indices:
alchemical_regions.append(
AlchemicalRegion(
alchemical_atoms=alchemical_indices[region],
name=region,
)
)

alchemical_factory = AbsoluteAlchemicalFactory()
alchemical_system = alchemical_factory.create_alchemical_system(system, alchemical_region)
alchemical_system = alchemical_factory.create_alchemical_system(
reference_system=system,
alchemical_regions=alchemical_regions
)

return alchemical_factory, alchemical_system, alchemical_indices

Expand All @@ -672,6 +744,7 @@ def _get_states(
lambdas: dict[str, list[float]],
solvent_comp: Optional[SolventComponent],
restraint_state: Optional[GlobalParameterState],
alchemical_indices: dict[str, list[int]],
) -> tuple[list[SamplerState], list[ThermodynamicState]]:
"""
Get a list of sampler and thermodynmic states from an
Expand All @@ -693,6 +766,9 @@ def _get_states(
The solvent component of the system, if there is one.
restraint_state : Optional[GlobalParameterState]
The restraint parameter control state, if there is one.
alchemical_indices : dict[str, list[int]]
Dictionary of the alchemical indices for each alchemical
region in the system.

Returns
-------
Expand All @@ -702,7 +778,17 @@ def _get_states(
A list of ThermodynamicState for each replica in the system.
"""
# Fetch an alchemical state
alchemical_state = AlchemicalState.from_system(alchemical_system)
if len(alchemical_indices.keys()) == 1:
alchemical_state = SingleRegionAlchemicalState.from_system(
alchemical_system
)
elif len(alchemical_indices.keys()) == 2:
alchemical_state = TwoRegionAlchemicalState.from_system(
alchemical_system
)
else:
errmsg = "more than two regions are not supported"
raise ValueError(errmsg)

# Set up the system constants
temperature = settings["thermo_settings"].temperature
Expand All @@ -714,7 +800,16 @@ def _get_states(
constants["pressure"] = ensure_quantity(pressure, "openmm")

# Get the thermodynamic parameter protocol
param_protocol = copy.deepcopy(lambdas)
param_protocol = {}

def _add_lambdas_to_protocol(protocol, lambdas, region_name):
protocol[f"lambda_electrostatics_{region_name}"] = lambdas["lambda_electrostatics"]
protocol[f"lambda_sterics_{region_name}"] = lambdas["lambda_sterics"]

param_protocol["lambda_restraints"] = lambdas["lambda_restraints"]
_add_lambdas_to_protocol(param_protocol, lambdas, "A")
if len(alchemical_indices.keys()) == 2:
_add_lambdas_to_protocol(param_protocol, lambdas, "B")

# Get the composable states
if restraint_state is not None:
Expand Down Expand Up @@ -1127,7 +1222,16 @@ def run(
# 5. Get lambdas
lambdas = self._get_lambda_schedule(settings)

# 6. Add restraints
# 6. Get alchemical ions
alchem_ion = self._get_alchemical_ion(
alchem_comps=alchem_comps,
comp_resids=comp_resids,
omm_topology=omm_topology,
positions=positions,
settings=settings,
)

# 7. Add restraints
# Note: when no restraint is applied, restrained_omm_system == omm_system
(
restraint_parameter_state,
Expand All @@ -1141,14 +1245,19 @@ def run(
alchem_comps,
comp_resids,
settings,
alchem_ion,
)

# 7. Get alchemical system
# 8. Get alchemical system
alchem_factory, alchem_system, alchem_indices = self._get_alchemical_system(
omm_topology, restrained_omm_system, comp_resids, alchem_comps
omm_topology,
restrained_omm_system,
comp_resids,
alchem_comps,
alchem_ion,
)

# 8. Get compound and sampler states
# 9. Get compound and sampler states
sampler_states, cmp_states = self._get_states(
alchem_system,
positions,
Expand All @@ -1157,9 +1266,10 @@ def run(
lambdas,
solv_comp,
restraint_parameter_state,
alchem_indices,
)

# 9. Create the multistate reporter & create PDB
# 10. Create the multistate reporter & create PDB
reporter = self._get_reporter(
omm_topology,
positions,
Expand All @@ -1169,18 +1279,18 @@ def run(

# Wrap in try/finally to avoid memory leak issues
try:
# 10. Get context caches
# 11. Get context caches
energy_ctx_cache, sampler_ctx_cache = self._get_ctx_caches(
settings["forcefield_settings"], settings["engine_settings"]
)

# 11. Get integrator
# 12. Get integrator
integrator = self._get_integrator(
settings["integrator_settings"],
settings["simulation_settings"],
)

# 12. Get sampler
# 13. Get sampler
sampler = self._get_sampler(
integrator,
reporter,
Expand All @@ -1192,7 +1302,7 @@ def run(
sampler_ctx_cache,
)

# 13. Run simulation
# 14. Run simulation
unit_result_dict = self._run_simulation(
sampler, reporter, settings, standard_state_corr, dry
)
Expand Down
30 changes: 29 additions & 1 deletion openfe/protocols/openmm_afe/equil_afe_settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,14 @@
* Add support for restraints

"""
from openff.units import unit as offunit

from gufe.settings import (
SettingsBaseModel,
OpenMMSystemGeneratorFFSettings,
ThermoSettings,
)
from gufe.settings.typing import NanometerQuantity
from openfe.protocols.openmm_utils.omm_settings import (
MultiStateSimulationSettings,
BaseSolvationSettings,
Expand All @@ -35,6 +37,7 @@
from openfe.protocols.restraint_utils.settings import (
BaseRestraintSettings,
BoreschRestraintSettings,
SpringConstantLinearQuantity,
)

import numpy as np
Expand All @@ -49,6 +52,31 @@ class AlchemicalSettings(SettingsBaseModel):
"""


class ABFEAlchemicalSettings(AlchemicalSettings):
"""
Absolute binding free energy alchemical settings.
"""
explicit_charge_correction: bool = True
"""
Whether or not to use explicit charge correction using
a co-alchemical ion.
"""
alchemical_ion_min_distance: NanometerQuantity = 1.0 * offunit.nanometer
"""
The minimum distance to search for a co-alchemical ion.
"""
alchemical_ion_solvent_spring_constant: SpringConstantLinearQuantity = 1000.0 * offunit.kilojoule_per_mole / offunit.nm**2
"""
The spring constant holding the ion away from the alchemical solute
in the solvent leg.
"""


class ABFERestraintSettings(BoreschRestraintSettings):
host_restraint_ids: list[int] | None = None
guest_restraint_ids: list[int] | None = None


class LambdaSettings(SettingsBaseModel):
"""Lambda schedule settings.

Expand Down Expand Up @@ -322,7 +350,7 @@ def must_be_positive(cls, v):
"""Settings for solvating the system in the complex."""

# Alchemical settings
alchemical_settings: AlchemicalSettings
alchemical_settings: ABFEAlchemicalSettings
"""
Alchemical protocol settings.
"""
Expand Down
Loading
Loading