diff --git a/feflow/protocols/nonequilibrium_cycling.py b/feflow/protocols/nonequilibrium_cycling.py index de6fc059..7f63de80 100644 --- a/feflow/protocols/nonequilibrium_cycling.py +++ b/feflow/protocols/nonequilibrium_cycling.py @@ -1,6 +1,6 @@ # Adapted from perses: https://github.com/choderalab/perses/blob/protocol-neqcyc/perses/protocols/nonequilibrium_cycling.py -from typing import Optional, List, Dict, Any +from typing import Optional, Any, Union from collections.abc import Iterable from itertools import chain @@ -988,15 +988,24 @@ def _create( self, stateA: ChemicalSystem, stateB: ChemicalSystem, - mapping: Optional[dict[str, ComponentMapping]] = None, + mapping: Optional[Union[ComponentMapping, list[ComponentMapping]]] = None, extends: Optional[ProtocolDAGResult] = None, ) -> list[ProtocolUnit]: - # Handle parameters - if mapping is None: - raise ValueError("`mapping` is required for this Protocol") + from openfe.protocols.openmm_rfe.equil_rfe_methods import ( + _validate_alchemical_components, + ) + from openfe.protocols.openmm_utils import system_validation + if extends: raise NotImplementedError("Can't extend simulations yet") + # Do manual validation until it is part of the protocol + # Get alchemical components & validate them + mapping + alchem_comps = system_validation.get_alchemical_components(stateA, stateB) + # raise an error if we have more than one mapping + _validate_alchemical_components(alchem_comps, mapping) + mapping = mapping[0] if isinstance(mapping, list) else mapping # type: ignore + # inputs to `ProtocolUnit.__init__` should either be `Gufe` objects # or JSON-serializable objects num_cycles = self.settings.num_cycles diff --git a/feflow/tests/conftest.py b/feflow/tests/conftest.py index 221eaeaa..fda42e02 100644 --- a/feflow/tests/conftest.py +++ b/feflow/tests/conftest.py @@ -221,8 +221,8 @@ def mapping_toluene_toluene(toluene): i: i for i in range(len(toluene.to_rdkit().GetAtoms())) } mapping_obj = LigandAtomMapping( - componentA=toluene, - componentB=toluene, + componentA=toluene.copy_with_replacements(name="toluene_a"), + componentB=toluene.copy_with_replacements(name="toluene_b"), componentA_to_componentB=mapping_toluene_to_toluene, ) return mapping_obj diff --git a/feflow/tests/test_nonequilibrium_cycling.py b/feflow/tests/test_nonequilibrium_cycling.py index cbfaea8e..27b0d491 100644 --- a/feflow/tests/test_nonequilibrium_cycling.py +++ b/feflow/tests/test_nonequilibrium_cycling.py @@ -307,7 +307,13 @@ def test_create_execute_gather( ], ) def test_create_execute_gather_toluene_to_toluene( - self, protocol, toluene_vacuum_system, mapping_toluene_toluene, tmpdir, request + self, + protocol, + toluene_vacuum_system, + mapping_toluene_toluene, + tmpdir, + request, + toluene, ): """ Perform 20 independent simulations of the NEQ cycling protocol for the toluene to toluene @@ -328,10 +334,16 @@ def test_create_execute_gather_toluene_to_toluene( import numpy as np protocol = request.getfixturevalue(protocol) - + # rename the components + toluene_state_a = toluene_vacuum_system.copy_with_replacements( + components={"ligand": toluene.copy_with_replacements(name="toluene_a")} + ) + toluene_state_b = toluene_vacuum_system.copy_with_replacements( + components={"ligand": toluene.copy_with_replacements(name="toluene_b")} + ) dag = protocol.create( - stateA=toluene_vacuum_system, - stateB=toluene_vacuum_system, + stateA=toluene_state_a, + stateB=toluene_state_b, name="Toluene vacuum transformation", mapping=mapping_toluene_toluene, ) @@ -495,6 +507,27 @@ def test_failing_partial_charge_assign( execute_DAG(dag, shared_basedir=shared, scratch_basedir=scratch) + def test_error_with_multiple_mappings( + self, + protocol_short, + benzene_vacuum_system, + toluene_vacuum_system, + mapping_benzene_toluene, + ): + """ + Make sure that when a list of mappings is passed that an error is raised. + """ + + with pytest.raises( + ValueError, match="A single LigandAtomMapping is expected for this Protocol" + ): + _ = protocol_short.create( + stateA=benzene_vacuum_system, + stateB=toluene_vacuum_system, + name="Test protocol", + mapping=[mapping_benzene_toluene, mapping_benzene_toluene], + ) + class TestSetupUnit: def test_setup_user_charges(