From dd1d639752ff645905c1fec84e89c62090135e7b Mon Sep 17 00:00:00 2001 From: jthorton Date: Mon, 23 Sep 2024 13:28:00 +0100 Subject: [PATCH 1/3] allow a list of mappings --- feflow/protocols/nonequilibrium_cycling.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/feflow/protocols/nonequilibrium_cycling.py b/feflow/protocols/nonequilibrium_cycling.py index de6fc059..c34a7efb 100644 --- a/feflow/protocols/nonequilibrium_cycling.py +++ b/feflow/protocols/nonequilibrium_cycling.py @@ -1,6 +1,6 @@ # Adapted from perses: https://github.com/choderalab/perses/blob/protocol-neqcyc/perses/protocols/nonequilibrium_cycling.py -from typing import Optional, List, Dict, Any +from typing import Optional, Any, Union from collections.abc import Iterable from itertools import chain @@ -988,7 +988,7 @@ def _create( self, stateA: ChemicalSystem, stateB: ChemicalSystem, - mapping: Optional[dict[str, ComponentMapping]] = None, + mapping: Optional[Union[ComponentMapping, list[ComponentMapping]]] = None, extends: Optional[ProtocolDAGResult] = None, ) -> list[ProtocolUnit]: # Handle parameters @@ -997,6 +997,8 @@ def _create( if extends: raise NotImplementedError("Can't extend simulations yet") + mapping = mapping[0] if isinstance(mapping, list) else mapping # type: ignore + # inputs to `ProtocolUnit.__init__` should either be `Gufe` objects # or JSON-serializable objects num_cycles = self.settings.num_cycles From 82447f0b0b75780ea422b6ac0a52fedd948a3cfd Mon Sep 17 00:00:00 2001 From: jthorton Date: Mon, 23 Sep 2024 14:49:52 +0100 Subject: [PATCH 2/3] validate the mapping and systems, add test --- feflow/protocols/nonequilibrium_cycling.py | 9 +++++++++ feflow/tests/test_nonequilibrium_cycling.py | 21 +++++++++++++++++++++ 2 files changed, 30 insertions(+) diff --git a/feflow/protocols/nonequilibrium_cycling.py b/feflow/protocols/nonequilibrium_cycling.py index c34a7efb..2153d9c9 100644 --- a/feflow/protocols/nonequilibrium_cycling.py +++ b/feflow/protocols/nonequilibrium_cycling.py @@ -991,12 +991,21 @@ def _create( mapping: Optional[Union[ComponentMapping, list[ComponentMapping]]] = None, extends: Optional[ProtocolDAGResult] = None, ) -> list[ProtocolUnit]: + from openfe.protocols.openmm_rfe.equil_rfe_methods import ( + _validate_alchemical_components, + ) + from openfe.protocols.openmm_utils import system_validation + # Handle parameters if mapping is None: raise ValueError("`mapping` is required for this Protocol") if extends: raise NotImplementedError("Can't extend simulations yet") + # Get alchemical components & validate them + mapping + alchem_comps = system_validation.get_alchemical_components(stateA, stateB) + # raise an error if we have more than one mapping + _validate_alchemical_components(alchem_comps, mapping) mapping = mapping[0] if isinstance(mapping, list) else mapping # type: ignore # inputs to `ProtocolUnit.__init__` should either be `Gufe` objects diff --git a/feflow/tests/test_nonequilibrium_cycling.py b/feflow/tests/test_nonequilibrium_cycling.py index 9154a356..25432e20 100644 --- a/feflow/tests/test_nonequilibrium_cycling.py +++ b/feflow/tests/test_nonequilibrium_cycling.py @@ -491,6 +491,27 @@ def test_failing_partial_charge_assign( execute_DAG(dag, shared_basedir=shared, scratch_basedir=scratch) + def test_error_with_multiple_mappings( + self, + protocol_short, + benzene_vacuum_system, + toluene_vacuum_system, + mapping_benzene_toluene, + ): + """ + Make sure that when a list of mappings is passed that an error is raised. + """ + + with pytest.raises( + ValueError, match="A single LigandAtomMapping is expected for this Protocol" + ): + _ = protocol_short.create( + stateA=benzene_vacuum_system, + stateB=toluene_vacuum_system, + name="Test protocol", + mapping=[mapping_benzene_toluene, mapping_benzene_toluene], + ) + class TestSetupUnit: def test_setup_user_charges( From 55b876519d202cd35cf2d5c738b27b6320e60c23 Mon Sep 17 00:00:00 2001 From: jthorton Date: Tue, 24 Sep 2024 15:56:19 +0100 Subject: [PATCH 3/3] remove double mapping validation, fix toluene to toluene test --- feflow/protocols/nonequilibrium_cycling.py | 4 +--- feflow/tests/conftest.py | 4 ++-- feflow/tests/test_nonequilibrium_cycling.py | 20 ++++++++++++++++---- 3 files changed, 19 insertions(+), 9 deletions(-) diff --git a/feflow/protocols/nonequilibrium_cycling.py b/feflow/protocols/nonequilibrium_cycling.py index 2153d9c9..7f63de80 100644 --- a/feflow/protocols/nonequilibrium_cycling.py +++ b/feflow/protocols/nonequilibrium_cycling.py @@ -996,12 +996,10 @@ def _create( ) from openfe.protocols.openmm_utils import system_validation - # Handle parameters - if mapping is None: - raise ValueError("`mapping` is required for this Protocol") if extends: raise NotImplementedError("Can't extend simulations yet") + # Do manual validation until it is part of the protocol # Get alchemical components & validate them + mapping alchem_comps = system_validation.get_alchemical_components(stateA, stateB) # raise an error if we have more than one mapping diff --git a/feflow/tests/conftest.py b/feflow/tests/conftest.py index 221eaeaa..fda42e02 100644 --- a/feflow/tests/conftest.py +++ b/feflow/tests/conftest.py @@ -221,8 +221,8 @@ def mapping_toluene_toluene(toluene): i: i for i in range(len(toluene.to_rdkit().GetAtoms())) } mapping_obj = LigandAtomMapping( - componentA=toluene, - componentB=toluene, + componentA=toluene.copy_with_replacements(name="toluene_a"), + componentB=toluene.copy_with_replacements(name="toluene_b"), componentA_to_componentB=mapping_toluene_to_toluene, ) return mapping_obj diff --git a/feflow/tests/test_nonequilibrium_cycling.py b/feflow/tests/test_nonequilibrium_cycling.py index 25432e20..2499a625 100644 --- a/feflow/tests/test_nonequilibrium_cycling.py +++ b/feflow/tests/test_nonequilibrium_cycling.py @@ -303,7 +303,13 @@ def test_create_execute_gather( ], ) def test_create_execute_gather_toluene_to_toluene( - self, protocol, toluene_vacuum_system, mapping_toluene_toluene, tmpdir, request + self, + protocol, + toluene_vacuum_system, + mapping_toluene_toluene, + tmpdir, + request, + toluene, ): """ Perform 20 independent simulations of the NEQ cycling protocol for the toluene to toluene @@ -324,10 +330,16 @@ def test_create_execute_gather_toluene_to_toluene( import numpy as np protocol = request.getfixturevalue(protocol) - + # rename the components + toluene_state_a = toluene_vacuum_system.copy_with_replacements( + components={"ligand": toluene.copy_with_replacements(name="toluene_a")} + ) + toluene_state_b = toluene_vacuum_system.copy_with_replacements( + components={"ligand": toluene.copy_with_replacements(name="toluene_b")} + ) dag = protocol.create( - stateA=toluene_vacuum_system, - stateB=toluene_vacuum_system, + stateA=toluene_state_a, + stateB=toluene_state_b, name="Toluene vacuum transformation", mapping=mapping_toluene_toluene, )