diff --git a/.coveragerc b/.coveragerc new file mode 100644 index 0000000..d663fad --- /dev/null +++ b/.coveragerc @@ -0,0 +1,4 @@ +[report] +exclude_lines = + @abstractmethod + @abc.abstractmethod \ No newline at end of file diff --git a/.github/workflows/tests.yaml b/.github/workflows/tests.yaml new file mode 100644 index 0000000..1d728da --- /dev/null +++ b/.github/workflows/tests.yaml @@ -0,0 +1,59 @@ +name: "CI" +on: + # on PR to main + pull_request: + branches: + - main + # on push to main + push: + branches: + - main + +concurrency: + group: ${{ github.workflow }}-${{ github.ref }} + cancel-in-progress: true + +defaults: + run: + shell: bash -l {0} + +jobs: + tests: + runs-on: ${{ matrix.os }} + name: "${{ matrix.os }} python-${{ matrix.python-version }}" + strategy: + fail-fast: false + matrix: + os: [ubuntu-latest, macos-latest] + python-version: + - "3.10" + - "3.11" + - "3.12" + steps: + - uses: actions/checkout@v4 + with: + fetch-depth: 0 + + - name: Setup micromamba + uses: mamba-org/setup-micromamba@v2 + with: + environment-name: stratocaster-test + init-shell: bash + cache-environment: true + # Since gufe the single dependency, install it without env file + create-args: >- + python=${{ matrix.python-version }} + gufe + + - name: Install stratocaster + # install test dependencies + run: python -m pip install -e ".[test]" + + - name: Environment information + run: | + micromamba info + micromamba list + + - name: Run tests + run: | + pytest -v src/stratocaster/tests/ diff --git a/.gitignore b/.gitignore index 46bbf13..1d40d00 100644 --- a/.gitignore +++ b/.gitignore @@ -1,2 +1,3 @@ -/stratocaster.egg-info/ -/stratocaster/__pycache__/ +/src/stratocaster/stratocaster.egg-info/ +/src/stratocaster/__pycache__/ +docs/_build/ diff --git a/.gitmodules b/.gitmodules new file mode 100644 index 0000000..e69de29 diff --git a/README.md b/README.md index 441ab6a..65ad925 100644 --- a/README.md +++ b/README.md @@ -1,5 +1,7 @@ # stratocaster +[![CI](https://github.com/OpenFreeEnergy/stratocaster/actions/workflows/tests.yaml/badge.svg)](https://github.com/OpenFreeEnergy/stratocaster/actions/workflows/tests.yaml) + A library for proposing a prioritization of Transformations within AlchemicalNetworks. ## Installation diff --git a/devtools/conda-envs/docs.yaml b/devtools/conda-envs/docs.yaml new file mode 100644 index 0000000..e604c69 --- /dev/null +++ b/devtools/conda-envs/docs.yaml @@ -0,0 +1,6 @@ +name: stratocaster-docs +channels: + - conda-forge + +dependencies: + - python>=3.12 diff --git a/devtools/conda-envs/test.yaml b/devtools/conda-envs/test.yaml new file mode 100644 index 0000000..c5ef553 --- /dev/null +++ b/devtools/conda-envs/test.yaml @@ -0,0 +1,8 @@ +name: stratocaster-test +channels: + - conda-forge + +dependencies: + - python>=3.10 + + - gufe>=1.2.0 diff --git a/devtools/conda-envs/test.yml b/devtools/conda-envs/test.yml deleted file mode 100644 index 2e895a4..0000000 --- a/devtools/conda-envs/test.yml +++ /dev/null @@ -1,13 +0,0 @@ -name: stratocaster-test -channels: - - conda-forge - -dependencies: - - python>=3.9 - - - gufe>=1.0.0 - - - pytest - - pytest-xdist - - pytest-cov - - coverage \ No newline at end of file diff --git a/docs/Makefile b/docs/Makefile new file mode 100644 index 0000000..d4bb2cb --- /dev/null +++ b/docs/Makefile @@ -0,0 +1,20 @@ +# Minimal makefile for Sphinx documentation +# + +# You can set these variables from the command line, and also +# from the environment for the first two. +SPHINXOPTS ?= +SPHINXBUILD ?= sphinx-build +SOURCEDIR = . +BUILDDIR = _build + +# Put it first so that "make" without argument is like "make help". +help: + @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) + +.PHONY: help Makefile + +# Catch-all target: route all unknown targets to Sphinx using the new +# "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). +%: Makefile + @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) diff --git a/docs/api.rst b/docs/api.rst new file mode 100644 index 0000000..9c16f1b --- /dev/null +++ b/docs/api.rst @@ -0,0 +1,3 @@ +API Reference +============= + diff --git a/docs/conf.py b/docs/conf.py new file mode 100644 index 0000000..1e54a51 --- /dev/null +++ b/docs/conf.py @@ -0,0 +1,27 @@ +# Configuration file for the Sphinx documentation builder. +# +# For the full list of built-in configuration values, see the documentation: +# https://www.sphinx-doc.org/en/master/usage/configuration.html + +# -- Project information ----------------------------------------------------- +# https://www.sphinx-doc.org/en/master/usage/configuration.html#project-information + +project = 'stratocaster' +copyright = '2024, Ian Kenney' +author = 'Ian Kenney' + +# -- General configuration --------------------------------------------------- +# https://www.sphinx-doc.org/en/master/usage/configuration.html#general-configuration + +extensions = [] + +templates_path = ['_templates'] +exclude_patterns = ['_build', 'Thumbs.db', '.DS_Store'] + + + +# -- Options for HTML output ------------------------------------------------- +# https://www.sphinx-doc.org/en/master/usage/configuration.html#options-for-html-output + +html_theme = 'ofe_sphinx_theme' +html_static_path = ['_static'] diff --git a/docs/developer_guide.rst b/docs/developer_guide.rst new file mode 100644 index 0000000..cf18c40 --- /dev/null +++ b/docs/developer_guide.rst @@ -0,0 +1,2 @@ +Developer Guide +=============== diff --git a/docs/getting_started.rst b/docs/getting_started.rst new file mode 100644 index 0000000..2069331 --- /dev/null +++ b/docs/getting_started.rst @@ -0,0 +1,31 @@ +Getting Started +=============== + +This guide will help you quickly get started using stratocaster. + +1. Installation +~~~~~~~~~~~~~~~ + +For installation instructions, refer to the :ref:`installation page`. + +2. Verify the installation +~~~~~~~~~~~~~~~~~~~~~~~~~~ + +Verify the installation was successful in a Python interpreter + +.. code:: python + + import statocaster + print(stratocaster.__version__) + +3. Quick-start example +~~~~~~~~~~~~~~~~~~~~~~ + +TODO + +Other resources +~~~~~~~~~~~~~~~ + +- `Source code repository `_ +- `GitHub issue tracker `_ + diff --git a/docs/index.rst b/docs/index.rst new file mode 100644 index 0000000..eba5348 --- /dev/null +++ b/docs/index.rst @@ -0,0 +1,17 @@ +stratocaster +============ + +The stratocaster library is complimentary to gufe and provides suggestions, via Strategies, for optimally executing Transformation Protocols defined in AlchemicalNetworks. + +This library includes a set of Strategy implementations as well as base classes to facilitate the creation of custom Strategy implementations. + +.. toctree:: + :maxdepth: 2 + :caption: Contents: + :hidden: + + installation + getting_started + user_guide + developer_guide + api diff --git a/docs/installation.rst b/docs/installation.rst new file mode 100644 index 0000000..24bfdca --- /dev/null +++ b/docs/installation.rst @@ -0,0 +1,28 @@ +.. _installation-label: + +Installation +============ + +The only requirement for installing statocaster is a working installation of gufe with a version 1.2.0 or higher. +For general use, we recommend installing from the conda-forge channel, which will also install gufe in the process. + +conda-forge channel +~~~~~~~~~~~~~~~~~~~ + +If you use conda, stratocaster can be installed through the conda-forge channel. + +.. code:: + + conda create -n statocaster-env + conda activate stratocaster-env + conda install -c conda-forge stratocaster + +Development version +~~~~~~~~~~~~~~~~~~~ + +If you want to install the latest development version of stratocaster, you can do so using pip, provided that you have a working installation of gufe (version >=1.2.0) in your environment. + +.. code:: + + pip install git+https://github.com/OpenFreeEnergy/stratocaster.git@main + diff --git a/docs/make.bat b/docs/make.bat new file mode 100644 index 0000000..32bb245 --- /dev/null +++ b/docs/make.bat @@ -0,0 +1,35 @@ +@ECHO OFF + +pushd %~dp0 + +REM Command file for Sphinx documentation + +if "%SPHINXBUILD%" == "" ( + set SPHINXBUILD=sphinx-build +) +set SOURCEDIR=. +set BUILDDIR=_build + +%SPHINXBUILD% >NUL 2>NUL +if errorlevel 9009 ( + echo. + echo.The 'sphinx-build' command was not found. Make sure you have Sphinx + echo.installed, then set the SPHINXBUILD environment variable to point + echo.to the full path of the 'sphinx-build' executable. Alternatively you + echo.may add the Sphinx directory to PATH. + echo. + echo.If you don't have Sphinx installed, grab it from + echo.https://www.sphinx-doc.org/ + exit /b 1 +) + +if "%1" == "" goto help + +%SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% +goto end + +:help +%SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% + +:end +popd diff --git a/docs/user_guide.rst b/docs/user_guide.rst new file mode 100644 index 0000000..415d843 --- /dev/null +++ b/docs/user_guide.rst @@ -0,0 +1,2 @@ +User guide +========== diff --git a/pyproject.toml b/pyproject.toml index 7a0f407..3163e41 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -24,10 +24,16 @@ Issues = "https://github.com/OpenFreeEnergy/stratocaster/issues" [project.optional-dependencies] test = [ "pytest", + "pytest-cov", ] dev = [ "stratocaster[test]", "black", + "isort", +] +docs = [ + "sphinx", + "ofe-sphinx-theme @ git+https://github.com/OpenFreeEnergy/ofe-sphinx-theme.git@main", ] [build-system] @@ -37,6 +43,9 @@ requires = [ ] build-backend = "setuptools.build_meta" +[tool.isort] +profile = "black" + [tool.versioningit] default-version = "1+unknown" diff --git a/src/stratocaster/base/__init__.py b/src/stratocaster/base/__init__.py index b0b88d2..dbf17b6 100644 --- a/src/stratocaster/base/__init__.py +++ b/src/stratocaster/base/__init__.py @@ -1,2 +1,2 @@ -from .strategy import Strategy, StrategyResult from .models import StrategySettings +from .strategy import Strategy, StrategyResult diff --git a/src/stratocaster/base/models.py b/src/stratocaster/base/models.py index 1718c93..e443b75 100644 --- a/src/stratocaster/base/models.py +++ b/src/stratocaster/base/models.py @@ -1,7 +1,6 @@ from gufe.settings.models import SettingsBaseModel +# TODO: docstrings class StrategySettings(SettingsBaseModel): - - def __init__(self): - normalize_weights: bool = True + pass diff --git a/src/stratocaster/base/strategy.py b/src/stratocaster/base/strategy.py index b64e533..3acf336 100644 --- a/src/stratocaster/base/strategy.py +++ b/src/stratocaster/base/strategy.py @@ -1,16 +1,17 @@ import abc -from typing import Self +from typing import TypeVar -from gufe.tokenization import GufeTokenizable -from gufe import AlchemicalNetwork -from gufe.protocols import ProtocolResult +from gufe import AlchemicalNetwork, ProtocolResult +from gufe.tokenization import GufeKey, GufeTokenizable from .models import StrategySettings +TProtocolResult = TypeVar("TProtocolResult", bound=ProtocolResult) +# TODO: docstrings class StrategyResult(GufeTokenizable): - def __init__(self, weights): + def __init__(self, weights: dict[GufeKey, float | None]): self._weights = weights @classmethod @@ -20,17 +21,51 @@ def _defaults(cls): def _to_dict(self) -> dict: return {"weights": self._weights} + # TODO: Return type from typing.Self when Python 3.10 is no longer supported @classmethod - def _from_dict(cls, dct: dict) -> Self: + def _from_dict(cls, dct: dict): return cls(**dct) + @property + def weights(self) -> dict[GufeKey, float | None]: + return self._weights + + def resolve(self) -> dict[GufeKey, float | None]: + weights = self.weights + weight_sum = sum([weight for weight in weights.values() if weight is not None]) + modified_weights = { + key: weight / weight_sum + for key, weight in weights.items() + if weight is not None + } + weights.update(modified_weights) + return weights + # TODO: docstrings class Strategy(GufeTokenizable): """An object that proposes the relative urgency of computing transformations within an AlchemicalNetwork.""" + _settings_cls: type[StrategySettings] + def __init__(self, settings: StrategySettings): + + if not hasattr(self.__class__, "_settings_cls"): + raise NotImplementedError( + f"class `{self.__class__.__qualname__}` must implement the `_settings_cls` attribute." + ) + + if not isinstance(settings, self._settings_cls): + raise ValueError( + f"`{self.__class__.__qualname__}` expected a `{self._settings_cls.__qualname__}` instance" + ) + self._settings = settings + super().__init__() + + @property + def settings(self) -> StrategySettings: + return self._settings @classmethod def _defaults(cls): @@ -39,8 +74,9 @@ def _defaults(cls): def _to_dict(self) -> dict: return {"settings": self._settings} + # TODO: Return type from typing.Self when Python 3.10 is no longer supported @classmethod - def _from_dict(cls, dct: dict) -> Self: + def _from_dict(cls, dct: dict): return cls(**dct) @classmethod @@ -52,13 +88,13 @@ def _default_settings(cls) -> StrategySettings: def _propose( self, alchemical_network: AlchemicalNetwork, - protocol_results: list[ProtocolResult], + protocol_results: dict[GufeKey, TProtocolResult], ) -> StrategyResult: raise NotImplementedError def propose( self, alchemical_network: AlchemicalNetwork, - protocol_results: list[ProtocolResult], + protocol_results: dict[GufeKey, TProtocolResult], ) -> StrategyResult: return self._propose(alchemical_network, protocol_results) diff --git a/src/stratocaster/strategies/__init__.py b/src/stratocaster/strategies/__init__.py new file mode 100644 index 0000000..592b5e5 --- /dev/null +++ b/src/stratocaster/strategies/__init__.py @@ -0,0 +1,3 @@ +from stratocaster.strategies.connectivity import ConnectivityStrategy + +__all__ = ["ConnectivityStrategy"] diff --git a/src/stratocaster/strategies/connectivity.py b/src/stratocaster/strategies/connectivity.py new file mode 100644 index 0000000..92c4367 --- /dev/null +++ b/src/stratocaster/strategies/connectivity.py @@ -0,0 +1,144 @@ +from gufe import AlchemicalNetwork, ProtocolResult +from gufe.tokenization import GufeKey + +from stratocaster.base import Strategy, StrategyResult +from stratocaster.base.models import StrategySettings + +try: + from pydantic.v1 import Field, root_validator, validator +except ImportError: + from pydantic import ( + Field, + root_validator, + validator, + ) + +import pydantic + + +# TODO: docstrings +class ConnectivityStrategySettings(StrategySettings): + + decay_rate: float = Field( + default=0.5, description="decay rate of the exponential decay penalty factor" + ) + cutoff: float | None = Field( + default=None, + description="unnormalized weight cutoff used for termination condition", + ) + max_runs: int | None = Field( + default=None, + description="the upper limit of protocol DAG results needed before a transformation is no longer weighed", + ) + + @validator("cutoff") + def validate_cutoff(cls, value): + if value is not None: + if not (0 < value): + raise ValueError("`cutoff` must be greater than 0") + return value + + @validator("decay_rate") + def validate_decay_rate(cls, value): + if not (0 < value < 1): + raise ValueError("`decay_rate` must be between 0 and 1") + return value + + @validator("max_runs") + def validate_max_runs(cls, value): + if value is not None: + if not value >= 1: + raise ValueError("`max_runs` must be greater than or equal to 1") + return value + + @root_validator + def check_cutoff_or_max_runs(cls, values): + max_runs, cutoff = values.get("max_runs"), values.get("cutoff") + + if max_runs is None and cutoff is None: + raise ValueError("At least one of `max_runs` or `cutoff` must be set") + + return values + + +# TODO: docstrings +class ConnectivityStrategy(Strategy): + + _settings_cls = ConnectivityStrategySettings + + def _exponential_decay_scaling(self, number_of_results: int, decay_rate: float): + return decay_rate**number_of_results + + def _propose( + self, + alchemical_network: AlchemicalNetwork, + protocol_results: dict[GufeKey, ProtocolResult], + ) -> StrategyResult: + """Propose `Transformation` weight recommendations based on high connectivity nodes. + + Parameters + ---------- + alchemical_network: AlchemicalNetwork + protocol_results: dict[GufeKey, ProtocolResult] + A dictionary whose keys are the `GufeKey`s of `Transformation`s in the `AlchemicalNetwork` + and whose values are the `ProtocolResult`s for those `Transformation`s. + + Returns + ------- + StrategyResult + A `StrategyResult` containing the proposed `Transformation` weights. + """ + + settings = self.settings + + # keep the type checker happy + assert isinstance(settings, ConnectivityStrategySettings) + + alchemical_network_mdg = alchemical_network.graph + weights: dict[GufeKey, float | None] = {} + + for state_a, state_b in alchemical_network_mdg.edges(): + num_neighbors_a = alchemical_network_mdg.degree(state_a) + num_neighbors_b = alchemical_network_mdg.degree(state_b) + + # linter-satisfying assertion + assert isinstance(num_neighbors_a, int) and isinstance(num_neighbors_b, int) + + transformation_key = alchemical_network_mdg.get_edge_data(state_a, state_b)[ + 0 + ]["object"].key + + match (protocol_results.get(transformation_key)): + case None: + transformation_n_protcol_dag_results = 0 + case pr: + assert isinstance(pr, ProtocolResult) + transformation_n_protcol_dag_results = pr.n_protocol_dag_results + + scaling_factor = self._exponential_decay_scaling( + transformation_n_protcol_dag_results, settings.decay_rate + ) + weight = scaling_factor * (num_neighbors_a + num_neighbors_b) / 2 + + match (settings.max_runs, settings.cutoff): + case (None, cutoff) if cutoff is not None: + if weight < cutoff: + weight = None + case (max_runs, None) if max_runs is not None: + if transformation_n_protcol_dag_results >= max_runs: + weight = None + case (max_runs, cutoff) if max_runs is not None and cutoff is not None: + if ( + weight < cutoff + or transformation_n_protcol_dag_results >= max_runs + ): + weight = None + + weights[transformation_key] = weight + + results = StrategyResult(weights=weights) + return results + + @classmethod + def _default_settings(cls) -> StrategySettings: + return ConnectivityStrategySettings(max_runs=3) diff --git a/src/stratocaster/tests/networks.py b/src/stratocaster/tests/networks.py new file mode 100644 index 0000000..a6aae9e --- /dev/null +++ b/src/stratocaster/tests/networks.py @@ -0,0 +1,117 @@ +""" +This file contains a modified version of code originally from the +gufe test module (conftest.py). The modifications were made to allow +easier construction of AlchemicalNetworks for development testing and +user examples. + +Original Commit SHA: 71a9c6610a9e13c8f7d588bd8309150557f104a5 +""" + +import importlib + +import gufe +from gufe.tests.test_protocol import DummyProtocol +from openff.units import unit +from rdkit import Chem + + +class BenzeneModifications: + + @staticmethod + def load_benzene_modifications(): + path = ( + importlib.resources.files("gufe.tests.data") / "benzene_modifications.sdf" + ) + supp = Chem.SDMolSupplier(str(path), removeHs=False) + return {m.GetProp("_Name"): m for m in list(supp)} + + _mod = load_benzene_modifications() + + def __class_getitem__(cls, key): + return gufe.SmallMoleculeComponent(cls._mod[key]) + + +def PDB_181L_path(): + path = importlib.resources.files("gufe.tests.data") / "181l.pdb" + return str(path) + + +def benzene_variants_star_map_transformations(): + + benzene = BenzeneModifications["benzene"] + + variants = tuple( + map( + lambda x: BenzeneModifications[x], + [ + "toluene", + "phenol", + "benzonitrile", + "anisole", + "benzaldehyde", + "styrene", + ], + ) + ) + + solv_comp = gufe.SolventComponent( + positive_ion="K", negative_ion="Cl", ion_concentration=0.0 * unit.molar + ) + prot_comp = gufe.ProteinComponent.from_pdb_file(PDB_181L_path()) + + # define the solvent chemical systems and transformations between + # benzene and the others + solvated_ligands = {} + solvated_ligand_transformations = {} + + solvated_ligands["benzene"] = gufe.ChemicalSystem( + {"solvent": solv_comp, "ligand": benzene}, name="benzene-solvent" + ) + + for ligand in variants: + solvated_ligands[ligand.name] = gufe.ChemicalSystem( + {"solvent": solv_comp, "ligand": ligand}, name=f"{ligand.name}-solvnet" + ) + solvated_ligand_transformations[("benzene", ligand.name)] = gufe.Transformation( + solvated_ligands["benzene"], + solvated_ligands[ligand.name], + protocol=DummyProtocol(settings=DummyProtocol.default_settings()), + mapping=None, + ) + + # define the complex chemical systems and transformations between + # benzene and the others + solvated_complexes = {} + solvated_complex_transformations = {} + + solvated_complexes["benzene"] = gufe.ChemicalSystem( + {"protein": prot_comp, "solvent": solv_comp, "ligand": benzene}, + name="benzene-complex", + ) + + for ligand in variants: + solvated_complexes[ligand.name] = gufe.ChemicalSystem( + {"protein": prot_comp, "solvent": solv_comp, "ligand": ligand}, + name=f"{ligand.name}-complex", + ) + solvated_complex_transformations[("benzene", ligand.name)] = ( + gufe.Transformation( + solvated_complexes["benzene"], + solvated_complexes[ligand.name], + protocol=DummyProtocol(settings=DummyProtocol.default_settings()), + mapping=None, + ) + ) + + return list(solvated_ligand_transformations.values()), list( + solvated_complex_transformations.values() + ) + + +def benzene_variants_star_map(): + solvated_ligand_transformations, solvated_complex_transformations = ( + benzene_variants_star_map_transformations() + ) + return gufe.AlchemicalNetwork( + solvated_ligand_transformations + solvated_complex_transformations + ) diff --git a/src/stratocaster/tests/test_connectivity_strategy.py b/src/stratocaster/tests/test_connectivity_strategy.py new file mode 100644 index 0000000..236c511 --- /dev/null +++ b/src/stratocaster/tests/test_connectivity_strategy.py @@ -0,0 +1,225 @@ +import math +from random import randint, shuffle + +import pytest +from gufe import AlchemicalNetwork +from gufe.tests.test_protocol import DummyProtocol, DummyProtocolResult + +from stratocaster.base.models import StrategySettings +from stratocaster.base.strategy import StrategyResult +from stratocaster.strategies.connectivity import ( + ConnectivityStrategy, + ConnectivityStrategySettings, +) +from stratocaster.tests.networks import ( + benzene_variants_star_map as _benzene_variants_star_map, +) + + +@pytest.fixture(scope="module") +def benzene_variants_star_map(): + return _benzene_variants_star_map() + + +from gufe.tokenization import GufeKey + +SETTINGS_VALID = [(0.5, 0.1, 10), (0.1, None, 10), (0.5, 0.1, None)] + + +@pytest.mark.parametrize( + ["decay_rate", "cutoff", "max_runs", "raises"], + [ + (0, None, None, ValueError), + (1, None, None, ValueError), + (0.5, 0, None, ValueError), + (0.5, None, 0, ValueError), + ] + + [(*vals, None) for vals in SETTINGS_VALID], # include all valid settings +) +def test_connectivity_strategy_settings(decay_rate, cutoff, max_runs, raises): + + def instantiate_settings(): + ConnectivityStrategySettings( + decay_rate=decay_rate, cutoff=cutoff, max_runs=max_runs + ) + + if raises: + with pytest.raises(raises): + instantiate_settings() + else: + instantiate_settings() + + +@pytest.fixture +def default_strategy(): + _settings = ConnectivityStrategy._default_settings() + return ConnectivityStrategy(_settings) + + +def test_propose_no_results( + default_strategy: ConnectivityStrategy, benzene_variants_star_map: AlchemicalNetwork +): + proposal: StrategyResult = default_strategy.propose(benzene_variants_star_map, {}) + + assert all([weight == 3.5 for weight in proposal._weights.values()]) + assert 1 == sum( + weight for weight in proposal.resolve().values() if weight is not None + ) + + +def test_propose_previous_results( + default_strategy: ConnectivityStrategy, benzene_variants_star_map: AlchemicalNetwork +): + + result_data: dict[GufeKey, DummyProtocolResult] = {} + for transformation in benzene_variants_star_map.edges: + transformation_key = transformation.key + result = DummyProtocolResult( + n_protocol_dag_results=2, info=f"key: {transformation_key}" + ) + result_data[transformation_key] = result + + results = default_strategy.propose(benzene_variants_star_map, result_data) + results_no_data = default_strategy.propose(benzene_variants_star_map, {}) + + # the raw weights should no longer be the same + assert results.weights != results_no_data.weights + # since each transformation had the same number of previous results, resolve + # should give back the same normalized weights + assert results.resolve() == results_no_data.resolve() + + +def test_propose_max_runs_termination( + default_strategy: ConnectivityStrategy, benzene_variants_star_map: AlchemicalNetwork +): + assert isinstance(default_strategy.settings, ConnectivityStrategySettings) + max_runs = default_strategy.settings.max_runs + assert isinstance(max_runs, int) + + result_data: dict[GufeKey, DummyProtocolResult] = {} + for transformation in benzene_variants_star_map.edges: + transformation_key = transformation.key + result = DummyProtocolResult( + n_protocol_dag_results=max_runs, info=f"key: {transformation_key}" + ) + result_data[transformation_key] = result + + results = default_strategy.propose(benzene_variants_star_map, result_data) + + # since the default strategy has a max_runs of 3, we expect all Nones + assert not [weight for weight in results.resolve().values() if weight is not None] + + +def test_propose_cutoff_num_runs_predictioned_termination(benzene_variants_star_map): + """We can predict the number of runs needed to terminate with a given cutoff. + + Each edge in benzene_variants_star_map has a base weight of 3.5. + """ + + settings = ConnectivityStrategySettings(cutoff=2, decay_rate=0.5) + strategy = ConnectivityStrategy(settings) + + assert isinstance(settings.cutoff, float) + + num_runs = math.floor( + math.log(settings.cutoff / 3.5) / math.log(settings.decay_rate) + ) + + result_data: dict[GufeKey, DummyProtocolResult] = {} + for transformation in benzene_variants_star_map.edges: + transformation_key = transformation.key + result = DummyProtocolResult( + n_protocol_dag_results=num_runs + 1, info=f"key: {transformation_key}" + ) + result_data[transformation_key] = result + + results = strategy.propose(benzene_variants_star_map, result_data) + + assert not [weight for weight in results.weights.values() if weight is not None] + + +@pytest.mark.parametrize(["decay_rate", "cutoff", "max_runs"], SETTINGS_VALID) +def test_simulated_termination( + default_strategy, benzene_variants_star_map, decay_rate, cutoff, max_runs +): + + settings = ConnectivityStrategySettings( + decay_rate=decay_rate, cutoff=cutoff, max_runs=max_runs + ) + default_strategy = ConnectivityStrategy(settings) + + def counts_to_result_data(counts_dict): + result_data = {} + for transformation_key, count in counts_dict.items(): + result = DummyProtocolResult( + n_protocol_dag_results=count, info=f"key: {transformation_key}" + ) + result_data[transformation_key] = result + return result_data + + def shuffle_take_n(keys_list, n): + shuffle(keys_list) + return keys_list[:n] + + # initial transforms + transformation_counts = { + transformation.key: 0 for transformation in benzene_variants_star_map.edges + } + + max_iterations = 100 + current_iteration = 0 + while current_iteration <= max_iterations: + + if current_iteration == max_iterations: + raise RuntimeError( + f"Strategy did not terminate in {max_iterations} iterations " + ) + + result_data = counts_to_result_data(transformation_counts) + proposal = default_strategy.propose(benzene_variants_star_map, result_data) + + # get random transformations from those with a non-None weight + resolved_keys = shuffle_take_n( + [key for key, weight in proposal.resolve().items() if weight is not None], 5 + ) + + if resolved_keys: + # pretend we ran each of the randomly selected protocols + for key in resolved_keys: + transformation_counts[key] += 1 + # if we got an empty list back, there are not more protocols to run + else: + break + current_iteration += 1 + + +def test_deterministic( + default_strategy: ConnectivityStrategy, benzene_variants_star_map: AlchemicalNetwork +): + + settings = default_strategy.settings + assert isinstance(settings, ConnectivityStrategySettings) + + max_runs = settings.max_runs + assert isinstance(max_runs, int) + + def random_runs(): + """Generate random randomized inputs for propose.""" + return { + transformation.key: DummyProtocolResult( + n_protocol_dag_results=randint(0, max_runs), + info=f"key: {transformation.key}", + ) + for transformation in benzene_variants_star_map.edges + } + + for _ in range(10): + random_protocol_results = random_runs() + proposal = default_strategy.propose( + benzene_variants_star_map, protocol_results=random_protocol_results + ) + for _ in range(3): + _proposal = default_strategy.propose( + benzene_variants_star_map, protocol_results=random_protocol_results + ) + assert _proposal == proposal diff --git a/src/stratocaster/tests/test_strategy_abstraction.py b/src/stratocaster/tests/test_strategy_abstraction.py new file mode 100644 index 0000000..2d50150 --- /dev/null +++ b/src/stratocaster/tests/test_strategy_abstraction.py @@ -0,0 +1,62 @@ +import pytest +from gufe import AlchemicalNetwork, ProtocolResult +from gufe.tokenization import GufeKey + +from stratocaster.base import Strategy, StrategySettings +from stratocaster.base.strategy import StrategyResult + + +class StrategyASettings(StrategySettings): + pass + + +class StrategyBSettings(StrategySettings): + pass + + +class StrategyNoSettings(Strategy): + + @classmethod + def _default_settings(cls) -> StrategySettings: + return cls._settings_cls() + + def _propose( + self, + alchemical_network: AlchemicalNetwork, + protocol_results: dict[GufeKey, ProtocolResult], + ) -> StrategyResult: + return StrategyResult({}) + + +class StrategyA(StrategyNoSettings): + _settings_cls = StrategyASettings + + +class StrategyB(StrategyNoSettings): + _settings_cls = StrategyBSettings + + +@pytest.mark.parametrize( + ("strategy", "settings"), + ((StrategyA, StrategyBSettings), (StrategyB, StrategyASettings)), +) +def test_incorrect_strategy_settings_passed(strategy, settings): + with pytest.raises(ValueError): + strategy(settings()) + + +@pytest.mark.parametrize( + ("strategy", "settings"), + ((StrategyA, StrategyASettings), (StrategyB, StrategyBSettings)), +) +def test_correct_strategy_settings_passed(strategy, settings): + strat_settings = settings() + strat = strategy(strat_settings) + + assert strat._settings_cls == settings + + +def test_no_settings_implemented(): + + with pytest.raises(NotImplementedError): + StrategyNoSettings(StrategyASettings()) diff --git a/src/stratocaster/tests/test_strategy_base.py b/src/stratocaster/tests/test_strategy_base.py new file mode 100644 index 0000000..6b20058 --- /dev/null +++ b/src/stratocaster/tests/test_strategy_base.py @@ -0,0 +1,48 @@ +from gufe import AlchemicalNetwork, ProtocolResult +from gufe.tokenization import GufeKey + +from stratocaster.base import Strategy, StrategyResult, StrategySettings + + +class TestStrategyResult: + + result = StrategyResult( + { + GufeKey("MyTransformation-ABC123"): 1, + GufeKey("MyTransformation-321CBA"): None, + GufeKey("MyOtherTransformation-789xyz"): 10, + } + ) + + def test_dict_roundtrip(self): + assert StrategyResult.from_dict(self.result.to_dict()) == self.result + + +class DummyStrategySettings(StrategySettings): + pass + + +class DummyStrategy(Strategy): + + _settings_cls = DummyStrategySettings + + @classmethod + def _default_settings(cls) -> DummyStrategySettings: + return DummyStrategySettings() + + def _propose( + self, + alchemical_network: AlchemicalNetwork, + protocol_results: dict[GufeKey, ProtocolResult], + ): + assert alchemical_network, protocol_results + return StrategyResult({}) + + +class TestStrategy: + + strategy = DummyStrategy(DummyStrategySettings()) + + def test_dict_roundtrip(self): + strategy_dict_form = self.strategy.to_dict() + assert DummyStrategy.from_dict(strategy_dict_form) == self.strategy