Skip to content
Merged
4 changes: 2 additions & 2 deletions openfe/protocols/openmm_afe/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,7 +168,7 @@ def _pre_equilibrate(
Parameters
----------
system : openmm.System
An OpenMM System to equilibrate.
The OpenMM System to equilibrate.
topology : openmm.app.Topology
OpenMM Topology of the System.
positions : openmm.unit.Quantity
Expand Down Expand Up @@ -502,7 +502,7 @@ def _get_omm_objects(
topology : app.Topology
OpenMM Topology object describing the parameterized system.
system : openmm.System
An non-alchemical OpenMM System of the simulated system.
A non-alchemical OpenMM System of the simulated system.
positions : openmm.unit.Quantity
Positions of the system.
comp_resids : dict[Component, npt.NDArray]
Expand Down
21 changes: 16 additions & 5 deletions openfe/protocols/openmm_md/plain_md_methods.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,18 +19,17 @@
from openff.units.openmm import from_openmm, to_openmm
import openmm.unit as omm_unit
from typing import Optional
from openmm import app
import pathlib
from typing import Any, Iterable
import uuid
import time
import numpy as np
import mdtraj
from mdtraj.reporters import XTCReporter
from openfe.utils import without_oechem_backend, log_system_probe
from gufe import (
settings, ChemicalSystem, SmallMoleculeComponent,
ProteinComponent, SolventComponent
settings,
ChemicalSystem,
SmallMoleculeComponent,
)
from openfe.protocols.openmm_utils.omm_settings import (
BasePartialChargeSettings,
Expand Down Expand Up @@ -684,8 +683,20 @@ def run(self, *, dry=False, verbose=True,
'nc': shared_basepath / output_settings.production_trajectory_filename,
'last_checkpoint': shared_basepath / output_settings.checkpoint_storage_filename,
}
if output_settings.equil_nvt_structure:
# The checkpoint file can not exist if frequency > sim length
if not output['last_checkpoint'].exists():
output['last_checkpoint'] = None

# The NVT PDB can be ommitted if we don't run the simulation
# Note: we could also just check the file exist
if (
output_settings.equil_nvt_structure
and sim_settings.equilibration_length_nvt is not None
):
output['nvt_equil_pdb'] = shared_basepath / output_settings.equil_nvt_structure
else:
output['nvt_equil_pdb'] = None

if output_settings.equil_npt_structure:
output['npt_equil_pdb'] = shared_basepath / output_settings.equil_npt_structure

Expand Down
9 changes: 9 additions & 0 deletions openfe/tests/protocols/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,19 @@
from importlib import resources
from rdkit import Chem
from rdkit.Geometry import Point3D
from openmm import Platform
import openfe
from openff.units import unit


@pytest.fixture
def available_platforms() -> set[str]:
return {

Check warning on line 15 in openfe/tests/protocols/conftest.py

View check run for this annotation

Codecov / codecov/patch

openfe/tests/protocols/conftest.py#L15

Added line #L15 was not covered by tests
Platform.getPlatform(i).getName()
for i in range(Platform.getNumPlatforms())
}


@pytest.fixture
def benzene_vacuum_system(benzene_modifications):
return openfe.ChemicalSystem(
Expand Down
Empty file.
Original file line number Diff line number Diff line change
Expand Up @@ -4,42 +4,21 @@
from gufe.protocols import execute_DAG
import pytest
from openff.units import unit
from openmm import Platform
import os
import pathlib

import openfe
from openfe.protocols import openmm_afe


@pytest.fixture
def available_platforms() -> set[str]:
return {Platform.getPlatform(i).getName() for i in range(Platform.getNumPlatforms())}


@pytest.fixture
def set_openmm_threads_1():
# for vacuum sims, we want to limit threads to one
# this fixture sets OPENMM_CPU_THREADS='1' for a single test, then reverts to previously held value
previous: str | None = os.environ.get('OPENMM_CPU_THREADS')

try:
os.environ['OPENMM_CPU_THREADS'] = '1'
yield
finally:
if previous is None:
del os.environ['OPENMM_CPU_THREADS']
else:
os.environ['OPENMM_CPU_THREADS'] = previous


@pytest.mark.integration # takes too long to be a slow test ~ 4 mins locally
@pytest.mark.flaky(reruns=3) # pytest-rerunfailures; we can get bad minimisation
@pytest.mark.parametrize('platform', ['CPU', 'CUDA'])
def test_openmm_run_engine(platform,
available_platforms,
benzene_modifications,
set_openmm_threads_1, tmpdir):
def test_openmm_run_engine(
platform,
available_platforms,
benzene_modifications,
tmpdir
):
if platform not in available_platforms:
pytest.skip(f"OpenMM Platform: {platform} not available")

Expand Down
Empty file.
Original file line number Diff line number Diff line change
Expand Up @@ -434,8 +434,13 @@ def solvent_protocol_dag(benzene_system):
def test_unit_tagging(solvent_protocol_dag, tmpdir):
# test that executing the Units includes correct generation and repeat info
dag_units = solvent_protocol_dag.protocol_units
with mock.patch('openfe.protocols.openmm_md.plain_md_methods.PlainMDProtocolUnit.run',
return_value={'nc': 'file.nc', 'last_checkpoint': 'chk.nc'}):
with mock.patch(
'openfe.protocols.openmm_md.plain_md_methods.PlainMDProtocolUnit.run',
return_value={
'nc': 'simulation.xtc',
'last_checkpoint': 'checkpoint.chk'
}
):
results = []
for u in dag_units:
ret = u.execute(context=gufe.Context(tmpdir, tmpdir))
Expand All @@ -452,8 +457,13 @@ def test_unit_tagging(solvent_protocol_dag, tmpdir):

def test_gather(solvent_protocol_dag, tmpdir):
# check .gather behaves as expected
with mock.patch('openfe.protocols.openmm_md.plain_md_methods.PlainMDProtocolUnit.run',
return_value={'nc': 'file.nc', 'last_checkpoint': 'chk.nc'}):
with mock.patch(
'openfe.protocols.openmm_md.plain_md_methods.PlainMDProtocolUnit.run',
return_value={
'nc': 'simulation.xtc',
'last_checkpoint': 'checkpoint.chk'
}
):
dagres = gufe.protocols.execute_DAG(solvent_protocol_dag,
shared_basedir=tmpdir,
scratch_basedir=tmpdir,
Expand Down
144 changes: 144 additions & 0 deletions openfe/tests/protocols/openmm_md/test_plain_md_slow.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,144 @@
# This code is part of OpenFE and is licensed under the MIT license.
# For details, see https://github.com/OpenFreeEnergy/openfe
import pathlib
import pytest
from openff.units import unit
from gufe.protocols import execute_DAG
from openfe.protocols import openmm_md


@pytest.mark.integration
@pytest.mark.parametrize('platform', ['CPU', 'CUDA'])
def test_vacuum_sim(
benzene_vacuum_system,
platform,
available_platforms,
tmpdir
):
if platform not in available_platforms:
pytest.skip(f"OpenMM Platform: {platform} is not available")

Check warning on line 19 in openfe/tests/protocols/openmm_md/test_plain_md_slow.py

View check run for this annotation

Codecov / codecov/patch

openfe/tests/protocols/openmm_md/test_plain_md_slow.py#L18-L19

Added lines #L18 - L19 were not covered by tests

# Run a vacuum MD simulation and check what files we get.
settings = openmm_md.PlainMDProtocol.default_settings()
settings.simulation_settings.equilibration_length_nvt = None
settings.simulation_settings.equilibration_length = 10 * unit.picosecond
settings.simulation_settings.production_length = 20 * unit.picosecond
settings.output_settings.checkpoint_interval = 40 * unit.picosecond
settings.forcefield_settings.nonbonded_method = "nocutoff"
settings.engine_settings.compute_platform = platform

Check warning on line 28 in openfe/tests/protocols/openmm_md/test_plain_md_slow.py

View check run for this annotation

Codecov / codecov/patch

openfe/tests/protocols/openmm_md/test_plain_md_slow.py#L22-L28

Added lines #L22 - L28 were not covered by tests

prot = openmm_md.PlainMDProtocol(settings)

Check warning on line 30 in openfe/tests/protocols/openmm_md/test_plain_md_slow.py

View check run for this annotation

Codecov / codecov/patch

openfe/tests/protocols/openmm_md/test_plain_md_slow.py#L30

Added line #L30 was not covered by tests

dag = prot.create(

Check warning on line 32 in openfe/tests/protocols/openmm_md/test_plain_md_slow.py

View check run for this annotation

Codecov / codecov/patch

openfe/tests/protocols/openmm_md/test_plain_md_slow.py#L32

Added line #L32 was not covered by tests
stateA=benzene_vacuum_system,
stateB=benzene_vacuum_system,
mapping=None,
)

workdir = pathlib.Path(str(tmpdir))

Check warning on line 38 in openfe/tests/protocols/openmm_md/test_plain_md_slow.py

View check run for this annotation

Codecov / codecov/patch

openfe/tests/protocols/openmm_md/test_plain_md_slow.py#L38

Added line #L38 was not covered by tests

r = execute_DAG(

Check warning on line 40 in openfe/tests/protocols/openmm_md/test_plain_md_slow.py

View check run for this annotation

Codecov / codecov/patch

openfe/tests/protocols/openmm_md/test_plain_md_slow.py#L40

Added line #L40 was not covered by tests
dag,
shared_basedir=workdir,
scratch_basedir=workdir,
keep_shared=True
)

assert r.ok()

Check warning on line 47 in openfe/tests/protocols/openmm_md/test_plain_md_slow.py

View check run for this annotation

Codecov / codecov/patch

openfe/tests/protocols/openmm_md/test_plain_md_slow.py#L47

Added line #L47 was not covered by tests

assert len(r.protocol_unit_results) == 1

Check warning on line 49 in openfe/tests/protocols/openmm_md/test_plain_md_slow.py

View check run for this annotation

Codecov / codecov/patch

openfe/tests/protocols/openmm_md/test_plain_md_slow.py#L49

Added line #L49 was not covered by tests

pur = r.protocol_unit_results[0]
unit_shared = tmpdir / f"shared_{pur.source_key}_attempt_0"
assert unit_shared.exists()
assert pathlib.Path(unit_shared).is_dir()

Check warning on line 54 in openfe/tests/protocols/openmm_md/test_plain_md_slow.py

View check run for this annotation

Codecov / codecov/patch

openfe/tests/protocols/openmm_md/test_plain_md_slow.py#L51-L54

Added lines #L51 - L54 were not covered by tests

# check the files
files = [

Check warning on line 57 in openfe/tests/protocols/openmm_md/test_plain_md_slow.py

View check run for this annotation

Codecov / codecov/patch

openfe/tests/protocols/openmm_md/test_plain_md_slow.py#L57

Added line #L57 was not covered by tests
"equil_npt.pdb",
"minimized.pdb",
"simulation.xtc",
"simulation.log",
"system.pdb"
]
for file in files:
assert (unit_shared / file).exists()

Check warning on line 65 in openfe/tests/protocols/openmm_md/test_plain_md_slow.py

View check run for this annotation

Codecov / codecov/patch

openfe/tests/protocols/openmm_md/test_plain_md_slow.py#L64-L65

Added lines #L64 - L65 were not covered by tests

# NVT PDB should not exist
assert not (unit_shared / "equil_nvt.pdb").exists()
assert not (unit_shared / "checkpoint.chk").exists()

Check warning on line 69 in openfe/tests/protocols/openmm_md/test_plain_md_slow.py

View check run for this annotation

Codecov / codecov/patch

openfe/tests/protocols/openmm_md/test_plain_md_slow.py#L68-L69

Added lines #L68 - L69 were not covered by tests

# check that the output file paths are correct
assert pur.outputs['system_pdb'] == unit_shared / "system.pdb"
assert pur.outputs['minimized_pdb'] == unit_shared / "minimized.pdb"
assert pur.outputs['nc'] == unit_shared / "simulation.xtc"
assert pur.outputs['last_checkpoint'] is None
assert pur.outputs['npt_equil_pdb'] == unit_shared / "equil_npt.pdb"
assert pur.outputs['nvt_equil_pdb'] is None

Check warning on line 77 in openfe/tests/protocols/openmm_md/test_plain_md_slow.py

View check run for this annotation

Codecov / codecov/patch

openfe/tests/protocols/openmm_md/test_plain_md_slow.py#L72-L77

Added lines #L72 - L77 were not covered by tests


@pytest.mark.integration
@pytest.mark.parametrize('platform', ['CUDA'])
def test_complex_solvent_sim_gpu(
benzene_complex_system,
platform,
available_platforms,
tmpdir,
):
if platform not in available_platforms:
pytest.skip(f"OpenMM Platform: {platform} is not available")

Check warning on line 89 in openfe/tests/protocols/openmm_md/test_plain_md_slow.py

View check run for this annotation

Codecov / codecov/patch

openfe/tests/protocols/openmm_md/test_plain_md_slow.py#L88-L89

Added lines #L88 - L89 were not covered by tests

# Run an MD simulation and check what files we get.
settings = openmm_md.PlainMDProtocol.default_settings()
settings.simulation_settings.equilibration_length_nvt = 50 * unit.picosecond
settings.simulation_settings.equilibration_length = 50 * unit.picosecond
settings.simulation_settings.production_length = 100 * unit.picosecond
settings.output_settings.checkpoint_interval = 10 * unit.picosecond
settings.engine_settings.compute_platform = platform

Check warning on line 97 in openfe/tests/protocols/openmm_md/test_plain_md_slow.py

View check run for this annotation

Codecov / codecov/patch

openfe/tests/protocols/openmm_md/test_plain_md_slow.py#L92-L97

Added lines #L92 - L97 were not covered by tests

prot = openmm_md.PlainMDProtocol(settings)

Check warning on line 99 in openfe/tests/protocols/openmm_md/test_plain_md_slow.py

View check run for this annotation

Codecov / codecov/patch

openfe/tests/protocols/openmm_md/test_plain_md_slow.py#L99

Added line #L99 was not covered by tests

dag = prot.create(

Check warning on line 101 in openfe/tests/protocols/openmm_md/test_plain_md_slow.py

View check run for this annotation

Codecov / codecov/patch

openfe/tests/protocols/openmm_md/test_plain_md_slow.py#L101

Added line #L101 was not covered by tests
stateA=benzene_complex_system,
stateB=benzene_complex_system,
mapping=None,
)

workdir = pathlib.Path(str(tmpdir))

Check warning on line 107 in openfe/tests/protocols/openmm_md/test_plain_md_slow.py

View check run for this annotation

Codecov / codecov/patch

openfe/tests/protocols/openmm_md/test_plain_md_slow.py#L107

Added line #L107 was not covered by tests

r = execute_DAG(

Check warning on line 109 in openfe/tests/protocols/openmm_md/test_plain_md_slow.py

View check run for this annotation

Codecov / codecov/patch

openfe/tests/protocols/openmm_md/test_plain_md_slow.py#L109

Added line #L109 was not covered by tests
dag,
shared_basedir=workdir,
scratch_basedir=workdir,
keep_shared=True
)

assert r.ok()

Check warning on line 116 in openfe/tests/protocols/openmm_md/test_plain_md_slow.py

View check run for this annotation

Codecov / codecov/patch

openfe/tests/protocols/openmm_md/test_plain_md_slow.py#L116

Added line #L116 was not covered by tests

assert len(r.protocol_unit_results) == 1

Check warning on line 118 in openfe/tests/protocols/openmm_md/test_plain_md_slow.py

View check run for this annotation

Codecov / codecov/patch

openfe/tests/protocols/openmm_md/test_plain_md_slow.py#L118

Added line #L118 was not covered by tests

pur = r.protocol_unit_results[0]
unit_shared = tmpdir / f"shared_{pur.source_key}_attempt_0"
assert unit_shared.exists()
assert pathlib.Path(unit_shared).is_dir()

Check warning on line 123 in openfe/tests/protocols/openmm_md/test_plain_md_slow.py

View check run for this annotation

Codecov / codecov/patch

openfe/tests/protocols/openmm_md/test_plain_md_slow.py#L120-L123

Added lines #L120 - L123 were not covered by tests

# check the files
files = [

Check warning on line 126 in openfe/tests/protocols/openmm_md/test_plain_md_slow.py

View check run for this annotation

Codecov / codecov/patch

openfe/tests/protocols/openmm_md/test_plain_md_slow.py#L126

Added line #L126 was not covered by tests
"checkpoint.chk",
"equil_nvt.pdb",
"equil_npt.pdb",
"minimized.pdb",
"simulation.xtc",
"simulation.log",
"system.pdb"
]
for file in files:
assert (unit_shared / file).exists()

Check warning on line 136 in openfe/tests/protocols/openmm_md/test_plain_md_slow.py

View check run for this annotation

Codecov / codecov/patch

openfe/tests/protocols/openmm_md/test_plain_md_slow.py#L135-L136

Added lines #L135 - L136 were not covered by tests

# check that the output file paths are correct
assert pur.outputs['system_pdb'] == unit_shared / "system.pdb"
assert pur.outputs['minimized_pdb'] == unit_shared / "minimized.pdb"
assert pur.outputs['nc'] == unit_shared / "simulation.xtc"
assert pur.outputs['last_checkpoint'] == unit_shared / "checkpoint.chk"
assert pur.outputs['nvt_equil_pdb'] == unit_shared / "equil_nvt.pdb"
assert pur.outputs['npt_equil_pdb'] == unit_shared / "equil_npt.pdb"

Check warning on line 144 in openfe/tests/protocols/openmm_md/test_plain_md_slow.py

View check run for this annotation

Codecov / codecov/patch

openfe/tests/protocols/openmm_md/test_plain_md_slow.py#L139-L144

Added lines #L139 - L144 were not covered by tests
82 changes: 82 additions & 0 deletions openfe/tests/protocols/openmm_md/test_plain_md_tokenization.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
# This ccode is part of OpenFE and is licensed under the MIT license.
# For details, see https://github.com/OpenFreeEnergy/openfe
import json
import pytest
import gufe
from gufe.tests.test_tokenization import GufeTokenizableTestsMixin
from openfe.protocols import openmm_md


@pytest.fixture
def protocol():
return openmm_md.PlainMDProtocol(
openmm_md.PlainMDProtocol.default_settings()
)


@pytest.fixture
def protocol_unit(protocol, benzene_system):
pus = protocol.create(
stateA=benzene_system,
stateB=benzene_system,
mapping=None,
)
return list(pus.protocol_units)[0]


@pytest.fixture
def protocol_result(md_json):
d = json.loads(md_json, cls=gufe.tokenization.JSON_HANDLER.decoder)
pr = gufe.ProtocolResult.from_dict(d['protocol_result'])
return pr


class TestPlainMDProtocol(GufeTokenizableTestsMixin):
cls = openmm_md.PlainMDProtocol
key = None
repr = "PlainMDProtocol-"

@pytest.fixture()
def instance(self, protocol):
return protocol

def test_repr(self, instance):
"""
Overwrites the base `test_repr` call to do a bit more.
"""
assert isinstance(repr(instance), str)
assert self.repr in repr(instance)


class TestPlainMDProtocolUnit(GufeTokenizableTestsMixin):
cls = openmm_md.PlainMDProtocolUnit
repr = "PlainMDProtocolUnit("
key = None

@pytest.fixture
def instance(self, protocol_unit):
return protocol_unit

def test_repr(self, instance):
"""
Overwrites the base `test_repr` call to do a bit more.
"""
assert isinstance(repr(instance), str)
assert self.repr in repr(instance)


class TestPlainMDProtocolResult(GufeTokenizableTestsMixin):
cls = openmm_md.PlainMDProtocolResult
key = None
repr = "PlainMDProtocolResult-"

@pytest.fixture()
def instance(self, protocol_result):
return protocol_result

def test_repr(self, instance):
"""
Overwrites the base `test_repr` call to do a bit more.
"""
assert isinstance(repr(instance), str)
assert self.repr in repr(instance)
Empty file.
Loading
Loading