From 1bab1276b8b13e19ccb800e6de5d794b3b922eb1 Mon Sep 17 00:00:00 2001 From: Ethan Holz Date: Sat, 7 Feb 2026 15:18:10 -0700 Subject: [PATCH 01/15] feat: add warehouse primitives for handling protocol units --- src/openfe/storage/warehouse.py | 30 ++++++++++++++++++++++++++---- 1 file changed, 26 insertions(+), 4 deletions(-) diff --git a/src/openfe/storage/warehouse.py b/src/openfe/storage/warehouse.py index c10baa12c..e0f2c771a 100644 --- a/src/openfe/storage/warehouse.py +++ b/src/openfe/storage/warehouse.py @@ -5,6 +5,8 @@ import re from typing import Literal, TypedDict +from gufe.protocols.protocoldag import ProtocolDAG +from gufe.protocols.protocolunit import ProtocolUnit from gufe.storage.externalresource import ExternalStorage, FileStorage from gufe.tokenization import ( JSON_HANDLER, @@ -35,6 +37,8 @@ class WarehouseStores(TypedDict): setup: ExternalStorage result: ExternalStorage + shared: ExternalStorage + tasks: ExternalStorage class WarehouseBaseClass: @@ -83,6 +87,12 @@ def delete(self, store_name: Literal["setup", "result"], location: str): store: ExternalStorage = self.stores[store_name] store.delete(location) + def store_task(self, obj: ProtocolUnit): + self._store_gufe_tokenizable("tasks", obj) + + def load_task(self, obj: GufeKey): + self._load_gufe_tokenizable(obj) + def store_setup_tokenizable(self, obj: GufeTokenizable): """Store a GufeTokenizable object in the setup store. @@ -134,7 +144,7 @@ def load_result_tokenizable(self, obj: GufeKey) -> GufeTokenizable: return self._load_gufe_tokenizable(gufe_key=obj) def exists(self, key: GufeKey) -> bool: - """Check if an object with the given key exists in any store. + """Check if an object with the given key exists in any store that holds tokenizables. 
Parameters ---------- @@ -171,7 +181,12 @@ def _get_store_for_key(self, key: GufeKey) -> ExternalStorage: return self.stores[name] raise ValueError(f"GufeKey {key} is not stored") - def _store_gufe_tokenizable(self, store_name: Literal["setup", "result"], obj: GufeTokenizable): + def _store_gufe_tokenizable( + self, + store_name: Literal["setup", "result", "tasks"], + obj: GufeTokenizable, + name: str | None = None, + ): """Store a GufeTokenizable object with deduplication. Parameters @@ -197,7 +212,10 @@ def _store_gufe_tokenizable(self, store_name: Literal["setup", "result"], obj: G data = json.dumps(keyed_dict, cls=JSON_HANDLER.encoder, sort_keys=True).encode( "utf-8" ) - target.store_bytes(gufe_key, data) + if name: + target.store_bytes(name, data) + else: + target.store_bytes(gufe_key, data) def _load_gufe_tokenizable(self, gufe_key: GufeKey) -> GufeTokenizable: """Load a deduplicated object from a GufeKey. @@ -315,5 +333,9 @@ class FileSystemWarehouse(WarehouseBaseClass): def __init__(self, root_dir: str = "warehouse"): setup_store = FileStorage(f"{root_dir}/setup") result_store = FileStorage(f"{root_dir}/result") - stores = WarehouseStores(setup=setup_store, result=result_store) + shared_store = FileStorage(f"{root_dir}/shared") + tasks_store = FileStorage(f"{root_dir}/tasks") + stores = WarehouseStores( + setup=setup_store, result=result_store, shared=shared_store, tasks=tasks_store + ) super().__init__(stores) From 3146240f79c78209b05346a7401489594b019995 Mon Sep 17 00:00:00 2001 From: Ethan Holz Date: Sat, 7 Feb 2026 15:18:40 -0700 Subject: [PATCH 02/15] feat: inital worker for exorcist --- environment.yml | 1 + src/openfe/orchestration/__init__.py | 57 ++++++++++++++++++++++ src/openfe/orchestration/exorcist_utils.py | 53 ++++++++++++++++++++ 3 files changed, 111 insertions(+) create mode 100644 src/openfe/orchestration/exorcist_utils.py diff --git a/environment.yml b/environment.yml index 0d563c254..716bc0f05 100644 --- a/environment.yml +++ 
b/environment.yml @@ -53,3 +53,4 @@ dependencies: - pip: - git+https://github.com/OpenFreeEnergy/gufe@main - git+https://github.com/fatiando/pooch@main # related to https://github.com/fatiando/pooch/issues/502 + - git+https://github.com/OpenFreeEnergy/exorcist@main diff --git a/src/openfe/orchestration/__init__.py b/src/openfe/orchestration/__init__.py index e69de29bb..d8e31db06 100644 --- a/src/openfe/orchestration/__init__.py +++ b/src/openfe/orchestration/__init__.py @@ -0,0 +1,57 @@ +from dataclasses import dataclass +from pathlib import Path + +from exorcist.taskdb import TaskStatusDB +from gufe.protocols.protocoldag import _pu_to_pur +from gufe.protocols.protocolunit import ( + Context, + ProtocolUnit, + ProtocolUnitFailure, + ProtocolUnitResult, +) +from gufe.storage.externalresource.filestorage import FileStorage +from gufe.tokenization import GufeKey + +from openfe.storage.warehouse import FileSystemWarehouse + +from .exorcist_utils import ( + alchemical_network_to_task_graph, + build_task_db_from_alchemical_network, +) + + +@dataclass +class Worker: + warehouse: FileSystemWarehouse + + def _get_task(self) -> ProtocolUnit: + # Right now, we are just going to assume it exists in the warehouse folder + location = Path("./warehouse/tasks.db") + + db: TaskStatusDB = TaskStatusDB.from_filename(location) + # The format for the taskid is going to "Transformation-:Unit" + taskid = db.check_out_task() + # Load the unit from warehouse and return + unit = taskid.split(":") + + return self.warehouse.load_task(unit) + + def execute_unit(self, scratch: Path): + # 1. Get task/unit + unit = self._get_task() + # 2. Constrcut the context + # NOTE: On changes to context, this can easily be replaced with external storage objects + # However, to satisfy the current work, we will use this implementation where we + # force the use of a FileSystemWarehouse and in turn can assert that an object is FileStorage. 
+ shared_store: FileStorage = self.warehouse.stores["shared"] + shared_root_dir = shared_store.root_dir + ctx = Context(scratch, shared=shared_root_dir) + results: dict[GufeKey, ProtocolUnitResult] = {} + inputs = _pu_to_pur(unit.inputs, results) + # 3. Execute unit + result = unit.execute(context=ctx, **inputs) + # if not result.ok(): + # Increment attempt in taskdb + # 4. output result to warehouse + # TODO: we may need to end up handling namespacing on the warehouse side for tokenizables + self.warehouse.store_result_tokenizable(result) diff --git a/src/openfe/orchestration/exorcist_utils.py b/src/openfe/orchestration/exorcist_utils.py new file mode 100644 index 000000000..fc30c8c11 --- /dev/null +++ b/src/openfe/orchestration/exorcist_utils.py @@ -0,0 +1,53 @@ +"""Utilities for building Exorcist task graphs and task databases.""" + +from pathlib import Path + +import exorcist +import networkx as nx +from gufe import AlchemicalNetwork + +from openfe.storage.warehouse import WarehouseBaseClass + + +def alchemical_network_to_task_graph( + alchemical_network: AlchemicalNetwork, warehouse: WarehouseBaseClass +) -> nx.DiGraph: + """Build a global task DAG from an AlchemicalNetwork.""" + + global_dag = nx.DiGraph() + for transformation in alchemical_network.edges: + dag = transformation.create() + for unit in dag.protocol_units: + node_id = f"{transformation.name}-{transformation.key}:{unit.name}-{unit.key}" + global_dag.add_node( + node_id, + label=f"{transformation.name}\n{unit.name}", + transformation_key=str(transformation.key), + protocol_unit_key=str(unit.key), + ) + warehouse.store_task(unit) + for u, v in dag.graph.edges: + u_id = f"{transformation.key}:{u.key}" + v_id = f"{transformation.key}:{v.key}" + global_dag.add_edge(u_id, v_id) + + if not nx.is_directed_acyclic_graph(global_dag): + raise ValueError("AlchemicalNetwork produced a task graph that is not a DAG.") + + return global_dag + + +def build_task_db_from_alchemical_network( + alchemical_network: 
AlchemicalNetwork, + warehouse: WarehouseBaseClass, + db_path: Path | None = None, + max_tries: int = 1, +) -> exorcist.TaskStatusDB: + """Create an Exorcist TaskStatusDB from an AlchemicalNetwork.""" + if db_path is None: + db_path = Path("tasks.db") + + global_dag = alchemical_network_to_task_graph(alchemical_network, warehouse) + db = exorcist.TaskStatusDB.from_filename(db_path) + db.add_task_network(global_dag, max_tries) + return db From 0b8edd9718ed01557c147de95e3d0d5c1d738265 Mon Sep 17 00:00:00 2001 From: Ethan Holz Date: Sat, 7 Feb 2026 15:57:22 -0700 Subject: [PATCH 03/15] test: add tests for warehouse --- src/openfe/tests/storage/test_warehouse.py | 96 +++++++++++++++++++--- 1 file changed, 84 insertions(+), 12 deletions(-) diff --git a/src/openfe/tests/storage/test_warehouse.py b/src/openfe/tests/storage/test_warehouse.py index 572769d70..d113cf1d7 100644 --- a/src/openfe/tests/storage/test_warehouse.py +++ b/src/openfe/tests/storage/test_warehouse.py @@ -19,18 +19,35 @@ class TestWarehouseBaseClass: def test_store_protocol_dag_result(self): pytest.skip("Not implemented yet") + @staticmethod + def _build_stores() -> WarehouseStores: + return WarehouseStores( + setup=MemoryStorage(), + result=MemoryStorage(), + shared=MemoryStorage(), + tasks=MemoryStorage(), + ) + + @staticmethod + def _get_protocol_unit(transformation): + dag = transformation.create() + return next(iter(dag.protocol_units)) + @staticmethod def _test_store_load_same_process( - obj, store_func_name, load_func_name, store_name: Literal["setup", "result"] + obj, + store_func_name, + load_func_name, + store_name: Literal["setup", "result", "tasks"], ): - setup_store = MemoryStorage() - result_store = MemoryStorage() - stores = WarehouseStores(setup=setup_store, result=result_store) + stores = TestWarehouseBaseClass._build_stores() client = WarehouseBaseClass(stores) store_func = getattr(client, store_func_name) load_func = getattr(client, load_func_name) - assert setup_store._data == {} - 
assert result_store._data == {} + assert stores["setup"]._data == {} + assert stores["result"]._data == {} + assert stores["shared"]._data == {} + assert stores["tasks"]._data == {} store_func(obj) store_under_test: MemoryStorage = stores[store_name] assert store_under_test._data != {} @@ -43,16 +60,16 @@ def _test_store_load_different_process( obj: GufeTokenizable, store_func_name, load_func_name, - store_name: Literal["setup", "result"], + store_name: Literal["setup", "result", "tasks"], ): - setup_store = MemoryStorage() - result_store = MemoryStorage() - stores = WarehouseStores(setup=setup_store, result=result_store) + stores = TestWarehouseBaseClass._build_stores() client = WarehouseBaseClass(stores) store_func = getattr(client, store_func_name) load_func = getattr(client, load_func_name) - assert setup_store._data == {} - assert result_store._data == {} + assert stores["setup"]._data == {} + assert stores["result"]._data == {} + assert stores["shared"]._data == {} + assert stores["tasks"]._data == {} store_func(obj) store_under_test: MemoryStorage = stores[store_name] assert store_under_test._data != {} @@ -65,6 +82,45 @@ def _test_store_load_different_process( assert reload == obj assert reload is not obj + def test_store_load_task_same_process(self, absolute_transformation): + unit = self._get_protocol_unit(absolute_transformation) + self._test_store_load_same_process(unit, "store_task", "load_task", "tasks") + + def test_store_load_task_different_process(self, absolute_transformation): + unit = self._get_protocol_unit(absolute_transformation) + self._test_store_load_different_process(unit, "store_task", "load_task", "tasks") + + def test_store_task_writes_to_tasks_store(self, absolute_transformation): + unit = self._get_protocol_unit(absolute_transformation) + stores = self._build_stores() + client = WarehouseBaseClass(stores) + client.store_task(unit) + + assert stores["tasks"]._data != {} + assert stores["setup"]._data == {} + assert 
stores["result"]._data == {} + assert stores["shared"]._data == {} + + def test_exists_finds_task_key(self, absolute_transformation): + unit = self._get_protocol_unit(absolute_transformation) + stores = self._build_stores() + client = WarehouseBaseClass(stores) + + client.store_task(unit) + + assert client.exists(unit.key) + + def test_load_task_returns_object(self, absolute_transformation): + unit = self._get_protocol_unit(absolute_transformation) + stores = self._build_stores() + client = WarehouseBaseClass(stores) + + client.store_task(unit) + loaded = client.load_task(unit.key) + + assert loaded is not None + assert isinstance(loaded, GufeTokenizable) + @pytest.mark.parametrize( "fixture", ["absolute_transformation", "complex_equilibrium"], @@ -164,6 +220,22 @@ def test_store_load_transformation_same_process(self, request, fixture): "load_setup_tokenizable", ) + def test_filesystemwarehouse_has_shared_and_tasks_stores(self, absolute_transformation): + unit = TestWarehouseBaseClass._get_protocol_unit(absolute_transformation) + + with tempfile.TemporaryDirectory() as tmpdir: + client = FileSystemWarehouse(tmpdir) + + assert "shared" in client.stores + assert "tasks" in client.stores + + client.stores["shared"].store_bytes("sentinel", b"shared-data") + with client.stores["shared"].load_stream("sentinel") as f: + assert f.read() == b"shared-data" + + client.store_task(unit) + assert client.exists(unit.key) + @pytest.mark.parametrize( "fixture", ["absolute_transformation", "complex_equilibrium"], From fe87d49cbdf0d62614de168e4044cce243f66bed Mon Sep 17 00:00:00 2001 From: Ethan Holz Date: Sat, 7 Feb 2026 15:57:41 -0700 Subject: [PATCH 04/15] fix: can now return protocol unit --- src/openfe/storage/warehouse.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/src/openfe/storage/warehouse.py b/src/openfe/storage/warehouse.py index e0f2c771a..0b46a8c6c 100644 --- a/src/openfe/storage/warehouse.py +++ b/src/openfe/storage/warehouse.py @@ -90,8 
+90,11 @@ def delete(self, store_name: Literal["setup", "result"], location: str): def store_task(self, obj: ProtocolUnit): self._store_gufe_tokenizable("tasks", obj) - def load_task(self, obj: GufeKey): - self._load_gufe_tokenizable(obj) + def load_task(self, obj: GufeKey) -> ProtocolUnit: + unit = self._load_gufe_tokenizable(obj) + if not isinstance(unit, ProtocolUnit): + raise ValueError("Unable to load ProtocolUnit") + return unit def store_setup_tokenizable(self, obj: GufeTokenizable): """Store a GufeTokenizable object in the setup store. From 9c378586770ea93fe87e650b340aed2879e6dee9 Mon Sep 17 00:00:00 2001 From: Ethan Holz Date: Tue, 10 Feb 2026 16:05:51 -0700 Subject: [PATCH 05/15] refactor: make things more consistent --- src/openfe/orchestration/__init__.py | 4 ++-- src/openfe/orchestration/exorcist_utils.py | 9 +++------ 2 files changed, 5 insertions(+), 8 deletions(-) diff --git a/src/openfe/orchestration/__init__.py b/src/openfe/orchestration/__init__.py index d8e31db06..42c169d96 100644 --- a/src/openfe/orchestration/__init__.py +++ b/src/openfe/orchestration/__init__.py @@ -32,9 +32,9 @@ def _get_task(self) -> ProtocolUnit: # The format for the taskid is going to "Transformation-:Unit" taskid = db.check_out_task() # Load the unit from warehouse and return - unit = taskid.split(":") + _, protocol_unit_key = taskid.split(":", maxsplit=1) - return self.warehouse.load_task(unit) + return self.warehouse.load_task(GufeKey(protocol_unit_key)) def execute_unit(self, scratch: Path): # 1. 
Get task/unit diff --git a/src/openfe/orchestration/exorcist_utils.py b/src/openfe/orchestration/exorcist_utils.py index fc30c8c11..51c269b2e 100644 --- a/src/openfe/orchestration/exorcist_utils.py +++ b/src/openfe/orchestration/exorcist_utils.py @@ -18,17 +18,14 @@ def alchemical_network_to_task_graph( for transformation in alchemical_network.edges: dag = transformation.create() for unit in dag.protocol_units: - node_id = f"{transformation.name}-{transformation.key}:{unit.name}-{unit.key}" + node_id = f"{str(transformation.key)}:{str(unit.key)}" global_dag.add_node( node_id, - label=f"{transformation.name}\n{unit.name}", - transformation_key=str(transformation.key), - protocol_unit_key=str(unit.key), ) warehouse.store_task(unit) for u, v in dag.graph.edges: - u_id = f"{transformation.key}:{u.key}" - v_id = f"{transformation.key}:{v.key}" + u_id = f"{str(transformation.key)}:{str(u.key)}" + v_id = f"{str(transformation.key)}:{str(v.key)}" global_dag.add_edge(u_id, v_id) if not nx.is_directed_acyclic_graph(global_dag): From 0b07cbad06daf34204f2d19b500de48bc2cf8df5 Mon Sep 17 00:00:00 2001 From: Ethan Holz Date: Tue, 10 Feb 2026 16:07:17 -0700 Subject: [PATCH 06/15] test: initial test setup for orchestration subpackage --- src/openfe/tests/orchestration/__init__.py | 2 + src/openfe/tests/orchestration/conftest.py | 118 +++++++++++++++++++++ 2 files changed, 120 insertions(+) create mode 100644 src/openfe/tests/orchestration/__init__.py create mode 100644 src/openfe/tests/orchestration/conftest.py diff --git a/src/openfe/tests/orchestration/__init__.py b/src/openfe/tests/orchestration/__init__.py new file mode 100644 index 000000000..efae32ddb --- /dev/null +++ b/src/openfe/tests/orchestration/__init__.py @@ -0,0 +1,2 @@ +# This code is part of OpenFE and is licensed under the MIT license. 
+# For details, see https://github.com/OpenFreeEnergy/openfe diff --git a/src/openfe/tests/orchestration/conftest.py b/src/openfe/tests/orchestration/conftest.py new file mode 100644 index 000000000..1851b7c05 --- /dev/null +++ b/src/openfe/tests/orchestration/conftest.py @@ -0,0 +1,118 @@ +import gufe +import pytest +from gufe import ChemicalSystem, SolventComponent +from gufe.tests.test_protocol import DummyProtocol +from openff.units import unit + + +@pytest.fixture +def solv_comp(): + yield SolventComponent(positive_ion="K", negative_ion="Cl", ion_concentration=0.0 * unit.molar) + + +@pytest.fixture +def solvated_complex(T4_protein_component, benzene_transforms, solv_comp): + return ChemicalSystem( + { + "ligand": benzene_transforms["toluene"], + "protein": T4_protein_component, + "solvent": solv_comp, + } + ) + + +@pytest.fixture +def solvated_ligand(benzene_transforms, solv_comp): + return ChemicalSystem( + { + "ligand": benzene_transforms["toluene"], + "solvent": solv_comp, + } + ) + + +@pytest.fixture +def absolute_transformation(solvated_ligand, solvated_complex): + return gufe.Transformation( + solvated_ligand, + solvated_complex, + protocol=DummyProtocol(settings=DummyProtocol.default_settings()), + mapping=None, + ) + + +@pytest.fixture +def complex_equilibrium(solvated_complex): + return gufe.NonTransformation( + solvated_complex, + protocol=DummyProtocol(settings=DummyProtocol.default_settings()), + ) + + +@pytest.fixture +def benzene_variants_star_map(benzene_transforms, solv_comp, T4_protein_component): + variants = ["toluene", "phenol", "benzonitrile", "anisole", "benzaldehyde", "styrene"] + + # define the solvent chemical systems and transformations between + # benzene and the others + solvated_ligands = {} + solvated_ligand_transformations = {} + + solvated_ligands["benzene"] = ChemicalSystem( + { + "solvent": solv_comp, + "ligand": benzene_transforms["benzene"], + }, + name="benzene-solvent", + ) + + for ligand in variants: + 
solvated_ligands[ligand] = ChemicalSystem( + { + "solvent": solv_comp, + "ligand": benzene_transforms[ligand], + }, + name=f"{ligand}-solvent", + ) + + solvated_ligand_transformations[("benzene", ligand)] = gufe.Transformation( + solvated_ligands["benzene"], + solvated_ligands[ligand], + protocol=DummyProtocol(settings=DummyProtocol.default_settings()), + mapping=None, + ) + + # define the complex chemical systems and transformations between + # benzene and the others + solvated_complexes = {} + solvated_complex_transformations = {} + + solvated_complexes["benzene"] = gufe.ChemicalSystem( + { + "protein": T4_protein_component, + "solvent": solv_comp, + "ligand": benzene_transforms["benzene"], + }, + name="benzene-complex", + ) + + for ligand in variants: + solvated_complexes[ligand] = gufe.ChemicalSystem( + { + "protein": T4_protein_component, + "solvent": solv_comp, + "ligand": benzene_transforms[ligand], + }, + name=f"{ligand}-complex", + ) + solvated_complex_transformations[("benzene", ligand)] = gufe.Transformation( + solvated_complexes["benzene"], + solvated_complexes[ligand], + protocol=DummyProtocol(settings=DummyProtocol.default_settings()), + mapping=None, + ) + + return gufe.AlchemicalNetwork( + list(solvated_ligand_transformations.values()) + + list(solvated_complex_transformations.values()) + ) From fcf5e6ae86e26b594b2450dda175fff6663457c1 Mon Sep 17 00:00:00 2001 From: Ethan Holz Date: Tue, 10 Feb 2026 16:07:46 -0700 Subject: [PATCH 07/15] test: initial exorcist utility testing --- .../orchestration/test_exorcist_utils.py | 210 ++++++++++++++++++ 1 file changed, 210 insertions(+) create mode 100644 src/openfe/tests/orchestration/test_exorcist_utils.py diff --git a/src/openfe/tests/orchestration/test_exorcist_utils.py b/src/openfe/tests/orchestration/test_exorcist_utils.py new file mode 100644 index 000000000..3ae9a85ad --- /dev/null +++ b/src/openfe/tests/orchestration/test_exorcist_utils.py @@ -0,0 +1,210 @@ +from pathlib import Path +from unittest 
import mock + +import exorcist +import networkx as nx +import pytest +import sqlalchemy as sqla +from gufe.tokenization import GufeKey + +from openfe.orchestration.exorcist_utils import ( + alchemical_network_to_task_graph, + build_task_db_from_alchemical_network, +) +from openfe.storage.warehouse import FileSystemWarehouse + + +class _RecordingWarehouse: + def __init__(self): + self.stored_tasks = [] + + def store_task(self, task): + self.stored_tasks.append(task) + + +def _network_units(benzene_variants_star_map): + units = [] + for transformation in benzene_variants_star_map.edges: + units.extend(transformation.create().protocol_units) + return units + + +@pytest.mark.parametrize("fixture", ["benzene_variants_star_map"]) +def test_alchemical_network_to_task_graph_stores_all_units(request, fixture): + warehouse = _RecordingWarehouse() + network = request.getfixturevalue(fixture) + expected_units = _network_units(network) + + alchemical_network_to_task_graph(network, warehouse) + + stored_unit_names = [str(unit.name) for unit in warehouse.stored_tasks] + expected_unit_names = [str(unit.name) for unit in expected_units] + + assert len(stored_unit_names) == len(expected_unit_names) + assert sorted(stored_unit_names) == sorted(expected_unit_names) + + +@pytest.mark.parametrize("fixture", ["benzene_variants_star_map"]) +def test_alchemical_network_to_task_graph_uses_canonical_task_ids(request, fixture): + warehouse = _RecordingWarehouse() + network = request.getfixturevalue(fixture) + + graph = alchemical_network_to_task_graph(network, warehouse) + + transformation_keys = {str(transformation.key) for transformation in network.edges} + expected_protocol_unit_keys = sorted(str(unit.key) for unit in warehouse.stored_tasks) + observed_protocol_unit_keys = [] + + for node in graph.nodes: + transformation_key, protocol_unit_key = node.split(":", maxsplit=1) + assert transformation_key in transformation_keys + observed_protocol_unit_keys.append(protocol_unit_key) + + assert 
sorted(observed_protocol_unit_keys) == expected_protocol_unit_keys + + +@pytest.mark.parametrize("fixture", ["benzene_variants_star_map"]) +def test_alchemical_network_to_task_graph_edges_reference_existing_nodes(request, fixture): + warehouse = _RecordingWarehouse() + network = request.getfixturevalue(fixture) + + graph = alchemical_network_to_task_graph(network, warehouse) + + assert len(graph.edges) > 0 + for u, v in graph.edges: + assert u in graph.nodes + assert v in graph.nodes + + +def test_alchemical_network_to_task_graph_raises_for_cycle(): + class _Unit: + def __init__(self, name: str, key: str): + self.name = name + self.key = key + + class _Transformation: + name = "cyclic" + key = "Transformation-cycle" + + def create(self): + unit_a = _Unit("unit-a", "ProtocolUnit-a") + unit_b = _Unit("unit-b", "ProtocolUnit-b") + dag = mock.Mock() + dag.protocol_units = [unit_a, unit_b] + dag.graph = nx.DiGraph() + dag.graph.add_nodes_from([unit_a, unit_b]) + dag.graph.add_edges_from([(unit_a, unit_b), (unit_b, unit_a)]) + return dag + + network = mock.Mock() + network.edges = [_Transformation()] + warehouse = mock.Mock() + + with pytest.raises(ValueError, match="not a DAG"): + alchemical_network_to_task_graph(network, warehouse) + + +@pytest.mark.parametrize("fixture", ["benzene_variants_star_map"]) +def test_build_task_db_checkout_order_is_dependency_safe(tmp_path, request, fixture): + network = request.getfixturevalue(fixture) + warehouse = FileSystemWarehouse(str(tmp_path / "warehouse")) + # Build the real sqlite task DB from a real alchemical network fixture. + db = build_task_db_from_alchemical_network( + network, + warehouse, + db_path=tmp_path / "tasks.db", + ) + + # Read task IDs and dependency edges from the persisted DB state. 
+ initial_task_rows = list(db.get_all_tasks()) + graph_taskids = {row.taskid for row in initial_task_rows} + with db.engine.connect() as conn: + dep_rows = conn.execute(sqla.select(db.dependencies_table)).all() + graph_edges = {(row._mapping["from"], row._mapping["to"]) for row in dep_rows} + + checkout_order = [] + # Hard upper bound prevents infinite checkout loops. + max_checkouts = len(graph_taskids) + print(f"Max Checkout={max_checkouts}") + for _ in range(max_checkouts): + taskid = db.check_out_task() + if taskid is None: + break + + checkout_order.append(taskid) + _, protocol_unit_key = taskid.split(":", maxsplit=1) + loaded_unit = warehouse.load_task(GufeKey(protocol_unit_key)) + assert str(loaded_unit.key) == protocol_unit_key + db.mark_task_completed(taskid, success=True) + + # Coverage/completion: every task is checked out exactly once. + observed_taskids = set(checkout_order) + assert observed_taskids == graph_taskids + assert len(checkout_order) == len(graph_taskids) + + # Dependency safety: upstream tasks must appear before downstream tasks. + checkout_index = {taskid: idx for idx, taskid in enumerate(checkout_order)} + for upstream, downstream in graph_edges: + assert checkout_index[upstream] < checkout_index[downstream] + + # Final DB state: all tasks are completed. 
+ task_rows = list(db.get_all_tasks()) + assert len(task_rows) == len(graph_taskids) + assert {row.taskid for row in task_rows} == graph_taskids + assert {row.status for row in task_rows} == {exorcist.TaskStatus.COMPLETED.value} + + +@pytest.mark.parametrize("fixture", ["benzene_variants_star_map"]) +def test_build_task_db_default_path(request, fixture): + network = request.getfixturevalue(fixture) + warehouse = mock.Mock() + fake_graph = nx.DiGraph() + fake_db = mock.Mock() + + with ( + mock.patch( + "openfe.orchestration.exorcist_utils.alchemical_network_to_task_graph", + return_value=fake_graph, + ) as task_graph_mock, + mock.patch( + "openfe.orchestration.exorcist_utils.exorcist.TaskStatusDB.from_filename", + return_value=fake_db, + ) as db_ctor, + ): + result = build_task_db_from_alchemical_network(network, warehouse) + + task_graph_mock.assert_called_once_with(network, warehouse) + db_ctor.assert_called_once_with(Path("tasks.db")) + fake_db.add_task_network.assert_called_once_with(fake_graph, 1) + assert result is fake_db + + +@pytest.mark.parametrize("fixture", ["benzene_variants_star_map"]) +def test_build_task_db_forwards_graph_and_max_tries(request, tmp_path, fixture): + network = request.getfixturevalue(fixture) + warehouse = mock.Mock() + fake_graph = nx.DiGraph() + fake_db = mock.Mock() + db_path = tmp_path / "custom_tasks.db" + + with ( + mock.patch( + "openfe.orchestration.exorcist_utils.alchemical_network_to_task_graph", + return_value=fake_graph, + ) as task_graph_mock, + mock.patch( + "openfe.orchestration.exorcist_utils.exorcist.TaskStatusDB.from_filename", + return_value=fake_db, + ) as db_ctor, + ): + result = build_task_db_from_alchemical_network( + network, + warehouse, + db_path=db_path, + max_tries=7, + ) + + task_graph_mock.assert_called_once_with(network, warehouse) + db_ctor.assert_called_once_with(db_path) + fake_db.add_task_network.assert_called_once_with(fake_graph, 7) + assert result is fake_db From 
76401f4eefff2ebd2f5a65cef91e1afbc9205111 Mon Sep 17 00:00:00 2001 From: Ethan Holz Date: Tue, 10 Feb 2026 16:42:30 -0700 Subject: [PATCH 08/15] refactor: provide a root path to the exorcist DB --- src/openfe/orchestration/__init__.py | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/src/openfe/orchestration/__init__.py b/src/openfe/orchestration/__init__.py index 42c169d96..98338df02 100644 --- a/src/openfe/orchestration/__init__.py +++ b/src/openfe/orchestration/__init__.py @@ -23,12 +23,10 @@ @dataclass class Worker: warehouse: FileSystemWarehouse + task_db_path: Path = Path("./warehouse/tasks.db") def _get_task(self) -> ProtocolUnit: - # Right now, we are just going to assume it exists in the warehouse folder - location = Path("./warehouse/tasks.db") - - db: TaskStatusDB = TaskStatusDB.from_filename(location) + db: TaskStatusDB = TaskStatusDB.from_filename(self.task_db_path) # The format for the taskid is going to "Transformation-:Unit" taskid = db.check_out_task() # Load the unit from warehouse and return @@ -43,7 +41,7 @@ def execute_unit(self, scratch: Path): # NOTE: On changes to context, this can easily be replaced with external storage objects # However, to satisfy the current work, we will use this implementation where we # force the use of a FileSystemWarehouse and in turn can assert that an object is FileStorage. 
- shared_store: FileStorage = self.warehouse.stores["shared"] + shared_store: FileStorage = self.warehouse.shared_store.root_dir shared_root_dir = shared_store.root_dir ctx = Context(scratch, shared=shared_root_dir) results: dict[GufeKey, ProtocolUnitResult] = {} From 815ddbfdf3a3c67c782a482dcab76f30816c4ea0 Mon Sep 17 00:00:00 2001 From: Ethan Holz Date: Tue, 10 Feb 2026 16:43:01 -0700 Subject: [PATCH 09/15] test: inital worker testing --- src/openfe/tests/orchestration/test_worker.py | 123 ++++++++++++++++++ 1 file changed, 123 insertions(+) create mode 100644 src/openfe/tests/orchestration/test_worker.py diff --git a/src/openfe/tests/orchestration/test_worker.py b/src/openfe/tests/orchestration/test_worker.py new file mode 100644 index 000000000..a99f872c0 --- /dev/null +++ b/src/openfe/tests/orchestration/test_worker.py @@ -0,0 +1,123 @@ +from pathlib import Path +from unittest import mock + +import exorcist +import gufe +import networkx as nx +import pytest +from gufe.protocols.protocolunit import ProtocolUnit + +from openfe.orchestration import Worker +from openfe.orchestration.exorcist_utils import build_task_db_from_alchemical_network +from openfe.storage.warehouse import FileSystemWarehouse + + +def _result_store_files(warehouse: FileSystemWarehouse) -> set[str]: + result_root = Path(warehouse.result_store.root_dir) + return {str(path.relative_to(result_root)) for path in result_root.rglob("*") if path.is_file()} + + +def _contains_protocol_unit(value) -> bool: + if isinstance(value, ProtocolUnit): + return True + if isinstance(value, dict): + return any(_contains_protocol_unit(item) for item in value.values()) + if isinstance(value, list): + return any(_contains_protocol_unit(item) for item in value) + return False + + +def _get_dependency_free_unit(absolute_transformation): + for unit in absolute_transformation.create().protocol_units: + if not _contains_protocol_unit(unit.inputs): + return unit + raise ValueError("No dependency-free protocol unit found 
for execution test setup.") + + +@pytest.fixture +def worker_with_real_db(tmp_path, absolute_transformation): + warehouse_root = tmp_path / "warehouse" + db_path = warehouse_root / "tasks.db" + warehouse = FileSystemWarehouse(str(warehouse_root)) + network = gufe.AlchemicalNetwork([absolute_transformation]) + db = build_task_db_from_alchemical_network(network, warehouse, db_path=db_path) + worker = Worker(warehouse=warehouse, task_db_path=db_path) + return worker, warehouse, db + + +@pytest.fixture +def worker_with_executable_task_db(tmp_path, absolute_transformation): + warehouse_root = tmp_path / "warehouse" + db_path = warehouse_root / "tasks.db" + warehouse = FileSystemWarehouse(str(warehouse_root)) + unit = _get_dependency_free_unit(absolute_transformation) + warehouse.store_task(unit) + + taskid = f"{absolute_transformation.key}:{unit.key}" + task_graph = nx.DiGraph() + task_graph.add_node(taskid) + + db = exorcist.TaskStatusDB.from_filename(db_path) + db.add_task_network(task_graph, 1) + + worker = Worker(warehouse=warehouse, task_db_path=db_path) + return worker, warehouse, db, unit + + +def test_get_task_uses_default_db_path_without_patching( + tmp_path, monkeypatch, absolute_transformation +): + monkeypatch.chdir(tmp_path) + warehouse = FileSystemWarehouse("warehouse") + db_path = Path("warehouse/tasks.db") + network = gufe.AlchemicalNetwork([absolute_transformation]) + db = build_task_db_from_alchemical_network(network, warehouse, db_path=db_path) + + worker = Worker(warehouse=warehouse) + loaded = worker._get_task() + + expected_keys = {task_row.taskid.split(":", maxsplit=1)[1] for task_row in db.get_all_tasks()} + assert worker.task_db_path == Path("./warehouse/tasks.db") + assert str(loaded.key) in expected_keys + + +def test_get_task_returns_task_with_canonical_protocol_unit_suffix(worker_with_real_db): + worker, warehouse, db = worker_with_real_db + + task_ids = [row.taskid for row in db.get_all_tasks()] + expected_protocol_unit_keys = 
{task_id.split(":", maxsplit=1)[1] for task_id in task_ids} + + loaded = worker._get_task() + reloaded = warehouse.load_task(loaded.key) + + assert str(loaded.key) in expected_protocol_unit_keys + assert loaded == reloaded + + +def test_execute_unit_stores_real_result(worker_with_executable_task_db, tmp_path): + worker, warehouse, _, _ = worker_with_executable_task_db + before = _result_store_files(warehouse) + + worker.execute_unit(scratch=tmp_path / "scratch") + + after = _result_store_files(warehouse) + assert len(after) > len(before) + + +def test_execute_unit_propagates_execute_error_without_store( + worker_with_executable_task_db, tmp_path +): + worker, warehouse, _, unit = worker_with_executable_task_db + before = _result_store_files(warehouse) + + with mock.patch.object( + type(unit), + "execute", + autospec=True, + side_effect=RuntimeError("unit execution failed"), + ): + with pytest.raises(RuntimeError, match="unit execution failed"): + worker.execute_unit(scratch=tmp_path / "scratch") + + after = _result_store_files(warehouse) + assert after == before From 98703f14f36fe258a619e535c18758b8ea3b3963 Mon Sep 17 00:00:00 2001 From: Ethan Holz Date: Wed, 11 Feb 2026 09:08:27 -0700 Subject: [PATCH 10/15] feat: add shared_store --- src/openfe/storage/warehouse.py | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/src/openfe/storage/warehouse.py b/src/openfe/storage/warehouse.py index 0b46a8c6c..843dfc9f0 100644 --- a/src/openfe/storage/warehouse.py +++ b/src/openfe/storage/warehouse.py @@ -314,6 +314,17 @@ def result_store(self): """ return self.stores["result"] + @property + def shared_store(self): + """Get the shared store. + + Returns + ------- + ExternalStorage + The shared storage location + """ + return self.stores["shared"] + class FileSystemWarehouse(WarehouseBaseClass): """Warehouse implementation using local filesystem storage. 
From cedc92fc754c06c0aaa9a24413be16ef4a18a073 Mon Sep 17 00:00:00 2001 From: Ethan Holz Date: Wed, 11 Feb 2026 09:10:09 -0700 Subject: [PATCH 11/15] feat: add better handling for CLI application Signed-off-by: Ethan Holz --- src/openfe/orchestration/__init__.py | 39 ++++++++++++++++++++-------- 1 file changed, 28 insertions(+), 11 deletions(-) diff --git a/src/openfe/orchestration/__init__.py b/src/openfe/orchestration/__init__.py index 98338df02..9c8410843 100644 --- a/src/openfe/orchestration/__init__.py +++ b/src/openfe/orchestration/__init__.py @@ -6,7 +6,6 @@ from gufe.protocols.protocolunit import ( Context, ProtocolUnit, - ProtocolUnitFailure, ProtocolUnitResult, ) from gufe.storage.externalresource.filestorage import FileStorage @@ -25,31 +24,49 @@ class Worker: warehouse: FileSystemWarehouse task_db_path: Path = Path("./warehouse/tasks.db") - def _get_task(self) -> ProtocolUnit: + def _checkout_task(self) -> tuple[str, ProtocolUnit] | None: db: TaskStatusDB = TaskStatusDB.from_filename(self.task_db_path) - # The format for the taskid is going to "Transformation-:Unit" + # The format for the taskid is "Transformation-:ProtocolUnit-" taskid = db.check_out_task() - # Load the unit from warehouse and return + if taskid is None: + return None + _, protocol_unit_key = taskid.split(":", maxsplit=1) + unit = self.warehouse.load_task(GufeKey(protocol_unit_key)) + return taskid, unit - return self.warehouse.load_task(GufeKey(protocol_unit_key)) + def _get_task(self) -> ProtocolUnit: + task = self._checkout_task() + if task is None: + raise RuntimeError("No AVAILABLE tasks found in the task database.") + _, unit = task + return unit - def execute_unit(self, scratch: Path): + def execute_unit(self, scratch: Path) -> tuple[str, ProtocolUnitResult] | None: # 1. Get task/unit - unit = self._get_task() + task = self._checkout_task() + if task is None: + return None + taskid, unit = task # 2. 
Construct the context
worker_with_executable_task_db before = _result_store_files(warehouse) - worker.execute_unit(scratch=tmp_path / "scratch") + execution = worker.execute_unit(scratch=tmp_path / "scratch") + assert execution is not None + taskid, _ = execution after = _result_store_files(warehouse) assert len(after) > len(before) + rows = list(db.get_all_tasks()) + status_by_taskid = {row.taskid: row.status for row in rows} + assert status_by_taskid[taskid] == exorcist.TaskStatus.COMPLETED.value def test_execute_unit_propagates_execute_error_without_store( worker_with_executable_task_db, tmp_path ): - worker, warehouse, _, unit = worker_with_executable_task_db + worker, warehouse, db, unit = worker_with_executable_task_db before = _result_store_files(warehouse) + taskid = list(db.get_all_tasks())[0].taskid with mock.patch.object( type(unit), @@ -121,3 +127,28 @@ def test_execute_unit_propagates_execute_error_without_store( after = _result_store_files(warehouse) assert after == before + rows = list(db.get_all_tasks()) + status_by_taskid = {row.taskid: row.status for row in rows} + assert status_by_taskid[taskid] == exorcist.TaskStatus.TOO_MANY_RETRIES.value + + +def test_checkout_task_returns_none_when_no_available_tasks(tmp_path): + warehouse_root = tmp_path / "warehouse" + db_path = warehouse_root / "tasks.db" + warehouse_root.mkdir(parents=True, exist_ok=True) + warehouse = FileSystemWarehouse(str(warehouse_root)) + exorcist.TaskStatusDB.from_filename(db_path) + worker = Worker(warehouse=warehouse, task_db_path=db_path) + + assert worker._checkout_task() is None + + +def test_execute_unit_returns_none_when_no_available_tasks(tmp_path): + warehouse_root = tmp_path / "warehouse" + db_path = warehouse_root / "tasks.db" + warehouse_root.mkdir(parents=True, exist_ok=True) + warehouse = FileSystemWarehouse(str(warehouse_root)) + exorcist.TaskStatusDB.from_filename(db_path) + worker = Worker(warehouse=warehouse, task_db_path=db_path) + + assert worker.execute_unit(scratch=tmp_path / 
"scratch") is None From 6405ea2d31b2743601a27a13d470c0e341d8c9eb Mon Sep 17 00:00:00 2001 From: Ethan Holz Date: Wed, 11 Feb 2026 09:22:51 -0700 Subject: [PATCH 13/15] feat: add exorcist worker to CLI --- src/openfecli/commands/worker.py | 83 ++++++++++++++++++++++++++++++++ 1 file changed, 83 insertions(+) create mode 100644 src/openfecli/commands/worker.py diff --git a/src/openfecli/commands/worker.py b/src/openfecli/commands/worker.py new file mode 100644 index 000000000..dbc52b6b3 --- /dev/null +++ b/src/openfecli/commands/worker.py @@ -0,0 +1,83 @@ +# This code is part of OpenFE and is licensed under the MIT license. +# For details, see https://github.com/OpenFreeEnergy/openfe + +import pathlib + +import click + +from openfecli import OFECommandPlugin +from openfecli.utils import print_duration, write + + +def _build_worker(warehouse_path: pathlib.Path, db_path: pathlib.Path): + from openfe.orchestration import Worker + from openfe.storage.warehouse import FileSystemWarehouse + + warehouse = FileSystemWarehouse(str(warehouse_path)) + return Worker(warehouse=warehouse, task_db_path=db_path) + + +def worker_main(warehouse_path: pathlib.Path, scratch: pathlib.Path | None): + db_path = warehouse_path / "tasks.db" + if not db_path.is_file(): + raise click.ClickException(f"Task database not found at: {db_path}") + + if scratch is None: + scratch = pathlib.Path.cwd() + + scratch.mkdir(parents=True, exist_ok=True) + + worker = _build_worker(warehouse_path, db_path) + + try: + execution = worker.execute_unit(scratch=scratch) + except Exception as exc: + raise click.ClickException(f"Task execution failed: {exc}") from exc + + if execution is None: + write("No available task in task graph.") + return None + + taskid, result = execution + if not result.ok(): + raise click.ClickException(f"Task '{taskid}' returned a failure result.") + + write(f"Completed task: {taskid}") + return result + + +@click.command("worker", short_help="Execute one available task from a filesystem 
warehouse") +@click.argument( + "warehouse_path", + type=click.Path( + exists=True, + readable=True, + file_okay=False, + dir_okay=True, + path_type=pathlib.Path, + ), +) +@click.option( + "--scratch", + "-s", + default=None, + type=click.Path( + writable=True, + file_okay=False, + dir_okay=True, + path_type=pathlib.Path, + ), + help="Directory for scratch files. Defaults to current working directory.", +) +@print_duration +def worker(warehouse_path: pathlib.Path, scratch: pathlib.Path | None): + """ + Execute one available task from a warehouse task graph. + + The warehouse directory must contain a ``tasks.db`` task database and task + payloads under ``tasks/`` created via OpenFE orchestration setup. + """ + worker_main(warehouse_path=warehouse_path, scratch=scratch) + + +PLUGIN = OFECommandPlugin(command=worker, section="Quickrun Executor", requires_ofe=(0, 3)) From d137d81bbb0ded006ef42e06429269b6732f2aea Mon Sep 17 00:00:00 2001 From: Ethan Holz Date: Wed, 11 Feb 2026 09:24:15 -0700 Subject: [PATCH 14/15] test: add for worker CLI command --- src/openfecli/tests/commands/test_worker.py | 107 ++++++++++++++++++++ 1 file changed, 107 insertions(+) create mode 100644 src/openfecli/tests/commands/test_worker.py diff --git a/src/openfecli/tests/commands/test_worker.py b/src/openfecli/tests/commands/test_worker.py new file mode 100644 index 000000000..6d3b55f7e --- /dev/null +++ b/src/openfecli/tests/commands/test_worker.py @@ -0,0 +1,107 @@ +from pathlib import Path +from unittest import mock + +from click.testing import CliRunner + +from openfecli.commands.worker import worker + + +class _SuccessfulResult: + def ok(self): + return True + + +class _FailedResult: + def ok(self): + return False + + +def test_worker_requires_task_database(): + runner = CliRunner() + with runner.isolated_filesystem(): + Path("warehouse").mkdir() + result = runner.invoke(worker, ["warehouse"]) + assert result.exit_code == 1 + assert "Task database not found at" in result.output + + +def 
test_worker_no_available_task_exits_zero(): + runner = CliRunner() + with runner.isolated_filesystem(): + warehouse_path = Path("warehouse") + warehouse_path.mkdir() + (warehouse_path / "tasks.db").touch() + + mock_worker = mock.Mock() + mock_worker.execute_unit.return_value = None + + with mock.patch( + "openfecli.commands.worker._build_worker", return_value=mock_worker + ) as build_worker: + result = runner.invoke(worker, ["warehouse"]) + + assert result.exit_code == 0 + assert "No available task in task graph." in result.output + build_worker.assert_called_once_with(warehouse_path, warehouse_path / "tasks.db") + kwargs = mock_worker.execute_unit.call_args.kwargs + assert kwargs["scratch"] == Path.cwd() + + +def test_worker_executes_one_task_and_reports_completion(): + runner = CliRunner() + with runner.isolated_filesystem(): + warehouse_path = Path("warehouse") + warehouse_path.mkdir() + (warehouse_path / "tasks.db").touch() + + mock_worker = mock.Mock() + mock_worker.execute_unit.return_value = ( + "Transformation-abc:ProtocolUnit-def", + _SuccessfulResult(), + ) + + with mock.patch("openfecli.commands.worker._build_worker", return_value=mock_worker): + result = runner.invoke(worker, ["warehouse", "--scratch", "scratch"]) + + assert result.exit_code == 0 + assert "Completed task: Transformation-abc:ProtocolUnit-def" in result.output + assert Path("scratch").is_dir() + kwargs = mock_worker.execute_unit.call_args.kwargs + assert kwargs["scratch"] == Path("scratch") + + +def test_worker_raises_when_result_is_failure(): + runner = CliRunner() + with runner.isolated_filesystem(): + warehouse_path = Path("warehouse") + warehouse_path.mkdir() + (warehouse_path / "tasks.db").touch() + + mock_worker = mock.Mock() + mock_worker.execute_unit.return_value = ( + "Transformation-abc:ProtocolUnit-def", + _FailedResult(), + ) + + with mock.patch("openfecli.commands.worker._build_worker", return_value=mock_worker): + result = runner.invoke(worker, ["warehouse"]) + + assert 
result.exit_code == 1 + assert "returned a failure result" in result.output + + +def test_worker_raises_when_execution_throws(): + runner = CliRunner() + with runner.isolated_filesystem(): + warehouse_path = Path("warehouse") + warehouse_path.mkdir() + (warehouse_path / "tasks.db").touch() + + mock_worker = mock.Mock() + mock_worker.execute_unit.side_effect = RuntimeError("boom") + + with mock.patch("openfecli.commands.worker._build_worker", return_value=mock_worker): + result = runner.invoke(worker, ["warehouse"]) + + assert result.exit_code == 1 + assert "Task execution failed: boom" in result.output From 946a9bee880cf17d219a2d0b0c3a9c3706c3a1eb Mon Sep 17 00:00:00 2001 From: Ethan Holz Date: Wed, 11 Feb 2026 20:29:28 -0700 Subject: [PATCH 15/15] docs: add numpy docstrings --- src/openfe/orchestration/__init__.py | 54 ++++++++++++++++++++++ src/openfe/orchestration/exorcist_utils.py | 51 ++++++++++++++++++-- 2 files changed, 102 insertions(+), 3 deletions(-) diff --git a/src/openfe/orchestration/__init__.py b/src/openfe/orchestration/__init__.py index 9c8410843..c46809f44 100644 --- a/src/openfe/orchestration/__init__.py +++ b/src/openfe/orchestration/__init__.py @@ -1,3 +1,5 @@ +"""Task orchestration utilities backed by Exorcist and a warehouse.""" + from dataclasses import dataclass from pathlib import Path @@ -21,10 +23,29 @@ @dataclass class Worker: + """Execute protocol units from an Exorcist task database. + + Parameters + ---------- + warehouse : FileSystemWarehouse + Warehouse used to load queued tasks and store execution results. + task_db_path : pathlib.Path, default=Path("./warehouse/tasks.db") + Path to the Exorcist SQLite task database. + """ + warehouse: FileSystemWarehouse task_db_path: Path = Path("./warehouse/tasks.db") def _checkout_task(self) -> tuple[str, ProtocolUnit] | None: + """Check out one available task and load its protocol unit. 
+ + Returns + ------- + tuple[str, ProtocolUnit] or None + The checked-out task ID and corresponding protocol unit, or + ``None`` if no task is currently available. + """ + db: TaskStatusDB = TaskStatusDB.from_filename(self.task_db_path) # The format for the taskid is "Transformation-:ProtocolUnit-" taskid = db.check_out_task() @@ -36,6 +57,19 @@ def _checkout_task(self) -> tuple[str, ProtocolUnit] | None: return taskid, unit def _get_task(self) -> ProtocolUnit: + """Return the next available protocol unit. + + Returns + ------- + ProtocolUnit + A protocol unit loaded from the warehouse. + + Raises + ------ + RuntimeError + Raised when no task is available in the task database. + """ + task = self._checkout_task() if task is None: raise RuntimeError("No AVAILABLE tasks found in the task database.") @@ -43,6 +77,26 @@ def _get_task(self) -> ProtocolUnit: return unit def execute_unit(self, scratch: Path) -> tuple[str, ProtocolUnitResult] | None: + """Execute one checked-out protocol unit and persist its result. + + Parameters + ---------- + scratch : pathlib.Path + Scratch directory passed to the protocol execution context. + + Returns + ------- + tuple[str, ProtocolUnitResult] or None + The task ID and execution result for the processed task, or + ``None`` if no task is currently available. + + Raises + ------ + Exception + Re-raises any exception thrown during protocol unit execution after + marking the task as failed. + """ + # 1. Get task/unit task = self._checkout_task() if task is None: diff --git a/src/openfe/orchestration/exorcist_utils.py b/src/openfe/orchestration/exorcist_utils.py index 51c269b2e..2ce8f1e9c 100644 --- a/src/openfe/orchestration/exorcist_utils.py +++ b/src/openfe/orchestration/exorcist_utils.py @@ -1,4 +1,8 @@ -"""Utilities for building Exorcist task graphs and task databases.""" +"""Utilities for building Exorcist task graphs and task databases. 
+ +This module translates an :class:`gufe.AlchemicalNetwork` into Exorcist task +structures and can initialize an Exorcist task database from that graph. +""" from pathlib import Path @@ -12,7 +16,28 @@ def alchemical_network_to_task_graph( alchemical_network: AlchemicalNetwork, warehouse: WarehouseBaseClass ) -> nx.DiGraph: - """Build a global task DAG from an AlchemicalNetwork.""" + """Build a global task DAG from an alchemical network. + + Parameters + ---------- + alchemical_network : AlchemicalNetwork + Network containing transformations to execute. + warehouse : WarehouseBaseClass + Warehouse used to persist protocol units as tasks while the graph is + constructed. + + Returns + ------- + nx.DiGraph + A directed acyclic graph where each node is a task ID in the form + ``":"`` and edges encode + protocol-unit dependencies. + + Raises + ------ + ValueError + Raised if the assembled task graph is not acyclic. + """ global_dag = nx.DiGraph() for transformation in alchemical_network.edges: @@ -40,7 +65,27 @@ def build_task_db_from_alchemical_network( db_path: Path | None = None, max_tries: int = 1, ) -> exorcist.TaskStatusDB: - """Create an Exorcist TaskStatusDB from an AlchemicalNetwork.""" + """Create and populate a task database from an alchemical network. + + Parameters + ---------- + alchemical_network : AlchemicalNetwork + Network containing transformations to convert into task records. + warehouse : WarehouseBaseClass + Warehouse used to persist protocol units while building the task DAG. + db_path : pathlib.Path or None, optional + Location of the SQLite-backed Exorcist database. If ``None``, defaults + to ``Path("tasks.db")`` in the current working directory. + max_tries : int, default=1 + Maximum number of retries for each task before Exorcist marks it as + ``TOO_MANY_RETRIES``. + + Returns + ------- + exorcist.TaskStatusDB + Initialized task database populated with graph nodes and dependency + edges derived from ``alchemical_network``. 
+ """ if db_path is None: db_path = Path("tasks.db")