diff --git a/gufe/protocols/protocol.py b/gufe/protocols/protocol.py index 918e1ba7f..ab30fc206 100644 --- a/gufe/protocols/protocol.py +++ b/gufe/protocols/protocol.py @@ -112,7 +112,10 @@ def _from_dict(cls, dct: dict): @classmethod @abc.abstractmethod - def _default_settings(cls) -> Settings: + def _default_settings(cls, *, + stateA: Optional[ChemicalSystem] = None, + stateB: Optional[ChemicalSystem] = None, + ) -> Settings: """Method to override in custom `Protocol` subclasses. Gives a usable instance of ``Settings`` that function as @@ -122,14 +125,23 @@ def _default_settings(cls) -> Settings: raise NotImplementedError() @classmethod - def default_settings(cls) -> Settings: + def default_settings(cls, *, + stateA: Optional[ChemicalSystem] = None, + stateB: Optional[ChemicalSystem] = None, + ) -> Settings: """Get the default settings for this `Protocol`. These can be modified and passed in as the `settings` for a new `Protocol` instance. + Parameters + ---------- + stateA, stateB: ChemicalSystem, optional + details of the chemistry that the Protocol will operate on. + Depending on implementation, this can be used to fine tune the + default settings which are used. """ - return cls._default_settings() + return cls._default_settings(stateA=stateA, stateB=stateB) @abc.abstractmethod def _create( diff --git a/gufe/tests/test_protocol.py b/gufe/tests/test_protocol.py index 32d3ff885..b4be78678 100644 --- a/gufe/tests/test_protocol.py +++ b/gufe/tests/test_protocol.py @@ -97,11 +97,24 @@ class DummyProtocol(Protocol): result_cls = DummyProtocolResult @classmethod - def _default_settings(cls): + def _default_settings(cls, *, + stateA=None, + stateB=None): + repeats = 21 + + if stateA and stateB: + # we got a hint of the chemistry involved + # so inspect and do something clever here + # for example if solvent is water we do repeats=21, + # otherwise repeats=42 + sc: gufe.SolventComponent = stateA.components.get('solvent', None) + if sc is not None: + repeats = 21 if sc.smiles == 'O' else 42 + return DummySpecificSettings( thermo_settings=settings.ThermoSettings(temperature=298 * unit.kelvin), forcefield_settings=settings.OpenMMSystemGeneratorFFSettings(), - n_repeats=21, + n_repeats=repeats, ) @classmethod @@ -505,7 +518,7 @@ def _defaults(cls): return {} @classmethod - def _default_settings(cls): + def _default_settings(cls, stateA=None, stateB=None): return settings.Settings.get_defaults() def _create( @@ -738,3 +751,39 @@ def test_settings_readonly(): p.settings.thermo_settings.temperature = 400.0 * unit.kelvin assert p.settings.thermo_settings.temperature == before + + +def test_customised_default_settings(): + # check that default_settings hook is able to inspect the chemistry to + # offer different defaults + cs1 = ChemicalSystem( + components={ + 'solvent': gufe.SolventComponent(smiles='CC') + } + ) + cs2 = ChemicalSystem( + components={ + 'solvent': gufe.SolventComponent(smiles='CC') + } + ) + + ds = DummyProtocol.default_settings(stateA=cs1, stateB=cs2) + + assert ds.n_repeats == 42 + + +def test_customised_default_settings_defaults(): + cs1 = ChemicalSystem( + components={ + 'solvent': gufe.SolventComponent(smiles='O') + } + ) + cs2 = ChemicalSystem( + components={ + 'solvent': gufe.SolventComponent(smiles='O') + } + ) + + ds = DummyProtocol.default_settings(stateA=cs1, stateB=cs2) + + assert ds.n_repeats == 21 diff --git a/gufe/tests/test_protocoldag.py b/gufe/tests/test_protocoldag.py index 9454eda01..bbbaa0dcc 100644 --- a/gufe/tests/test_protocoldag.py +++ b/gufe/tests/test_protocoldag.py @@ -41,7 +41,7 @@ class WriterProtocol(gufe.Protocol): result_cls = WriterProtocolResult @classmethod - def _default_settings(cls): + def _default_settings(cls, stateA=None, stateB=None): return WriterSettings( thermo_settings=gufe.settings.ThermoSettings(temperature=298 * unit.kelvin), forcefield_settings=gufe.settings.OpenMMSystemGeneratorFFSettings(),