diff --git a/environment.yml b/environment.yml
index 0d563c254..716bc0f05 100644
--- a/environment.yml
+++ b/environment.yml
@@ -53,3 +53,4 @@ dependencies:
   - pip:
     - git+https://github.com/OpenFreeEnergy/gufe@main
    - git+https://github.com/fatiando/pooch@main  # related to https://github.com/fatiando/pooch/issues/502
+    - git+https://github.com/OpenFreeEnergy/exorcist@main
diff --git a/src/openfe/orchestration/__init__.py b/src/openfe/orchestration/__init__.py
index e69de29bb..c46809f44 100644
--- a/src/openfe/orchestration/__init__.py
+++ b/src/openfe/orchestration/__init__.py
@@ -0,0 +1,126 @@
+"""Task orchestration utilities backed by Exorcist and a warehouse."""
+
+from dataclasses import dataclass
+from pathlib import Path
+
+from exorcist.taskdb import TaskStatusDB
+from gufe.protocols.protocoldag import _pu_to_pur
+from gufe.protocols.protocolunit import (
+    Context,
+    ProtocolUnit,
+    ProtocolUnitResult,
+)
+from gufe.storage.externalresource.filestorage import FileStorage
+from gufe.tokenization import GufeKey
+
+from openfe.storage.warehouse import FileSystemWarehouse
+
+from .exorcist_utils import (
+    alchemical_network_to_task_graph,
+    build_task_db_from_alchemical_network,
+)
+
+
+@dataclass
+class Worker:
+    """Execute protocol units from an Exorcist task database.
+
+    Parameters
+    ----------
+    warehouse : FileSystemWarehouse
+        Warehouse used to load queued tasks and store execution results.
+    task_db_path : pathlib.Path, default=Path("./warehouse/tasks.db")
+        Path to the Exorcist SQLite task database.
+    """
+
+    warehouse: FileSystemWarehouse
+    task_db_path: Path = Path("./warehouse/tasks.db")
+
+    def _checkout_task(self) -> tuple[str, ProtocolUnit] | None:
+        """Check out one available task and load its protocol unit.
+
+        Returns
+        -------
+        tuple[str, ProtocolUnit] or None
+            The checked-out task ID and corresponding protocol unit, or
+            ``None`` if no task is currently available.
+        """
+
+        db: TaskStatusDB = TaskStatusDB.from_filename(self.task_db_path)
+        # The format for the taskid is "Transformation-<key>:ProtocolUnit-<key>"
+        taskid = db.check_out_task()
+        if taskid is None:
+            return None
+
+        _, protocol_unit_key = taskid.split(":", maxsplit=1)
+        unit = self.warehouse.load_task(GufeKey(protocol_unit_key))
+        return taskid, unit
+
+    def _get_task(self) -> ProtocolUnit:
+        """Return the next available protocol unit.
+
+        Returns
+        -------
+        ProtocolUnit
+            A protocol unit loaded from the warehouse.
+
+        Raises
+        ------
+        RuntimeError
+            Raised when no task is available in the task database.
+        """
+
+        task = self._checkout_task()
+        if task is None:
+            raise RuntimeError("No AVAILABLE tasks found in the task database.")
+        _, unit = task
+        return unit
+
+    def execute_unit(self, scratch: Path) -> tuple[str, ProtocolUnitResult] | None:
+        """Execute one checked-out protocol unit and persist its result.
+
+        Parameters
+        ----------
+        scratch : pathlib.Path
+            Scratch directory passed to the protocol execution context.
+
+        Returns
+        -------
+        tuple[str, ProtocolUnitResult] or None
+            The task ID and execution result for the processed task, or
+            ``None`` if no task is currently available.
+
+        Raises
+        ------
+        Exception
+            Re-raises any exception thrown during protocol unit execution after
+            marking the task as failed.
+        """
+
+        # 1. Get task/unit
+        task = self._checkout_task()
+        if task is None:
+            return None
+        taskid, unit = task
+        # 2. Construct the context
+        # NOTE: On changes to context, this can easily be replaced with external storage objects.
+        # However, to satisfy the current work, we will use this implementation where we
+        # force the use of a FileSystemWarehouse and in turn can assert that an object is FileStorage.
+        shared_store: FileStorage = self.warehouse.stores["shared"]
+        shared_root_dir = shared_store.root_dir
+        ctx = Context(scratch, shared=shared_root_dir)
+        results: dict[GufeKey, ProtocolUnitResult] = {}
+        inputs = _pu_to_pur(unit.inputs, results)
+        db: TaskStatusDB = TaskStatusDB.from_filename(self.task_db_path)
+        # 3. Execute unit
+        try:
+            result = unit.execute(context=ctx, **inputs)
+        except Exception:
+            db.mark_task_completed(taskid, success=False)
+            raise
+
+        db.mark_task_completed(taskid, success=result.ok())
+        # 4. Output result to warehouse
+        # TODO: we may need to end up handling namespacing on the warehouse side for tokenizables
+        self.warehouse.store_result_tokenizable(result)
+        return taskid, result
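A minimal sketch of how a polling worker drains the queue that `Worker.execute_unit` exposes. The `warehouse/` root, the `scratch/` directory, and a previously populated `tasks.db` are assumptions for illustration, not part of this diff:

```python
# Sketch only: assumes a warehouse at ./warehouse whose tasks.db was already
# populated (see build_task_db_from_alchemical_network below).
from pathlib import Path

from openfe.orchestration import Worker
from openfe.storage.warehouse import FileSystemWarehouse

warehouse = FileSystemWarehouse("warehouse")
worker = Worker(warehouse=warehouse, task_db_path=Path("warehouse/tasks.db"))

scratch = Path("scratch")  # assumed scratch directory
scratch.mkdir(exist_ok=True)

# execute_unit() checks out one AVAILABLE task, runs it, stores the result,
# and returns None once no task remains, so draining the queue is a loop.
while (execution := worker.execute_unit(scratch=scratch)) is not None:
    taskid, result = execution
    print(f"{taskid}: ok={result.ok()}")
```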
diff --git a/src/openfe/orchestration/exorcist_utils.py b/src/openfe/orchestration/exorcist_utils.py
new file mode 100644
index 000000000..2ce8f1e9c
--- /dev/null
+++ b/src/openfe/orchestration/exorcist_utils.py
@@ -0,0 +1,95 @@
+"""Utilities for building Exorcist task graphs and task databases.
+
+This module translates an :class:`gufe.AlchemicalNetwork` into Exorcist task
+structures and can initialize an Exorcist task database from that graph.
+"""
+
+from pathlib import Path
+
+import exorcist
+import networkx as nx
+from gufe import AlchemicalNetwork
+
+from openfe.storage.warehouse import WarehouseBaseClass
+
+
+def alchemical_network_to_task_graph(
+    alchemical_network: AlchemicalNetwork, warehouse: WarehouseBaseClass
+) -> nx.DiGraph:
+    """Build a global task DAG from an alchemical network.
+
+    Parameters
+    ----------
+    alchemical_network : AlchemicalNetwork
+        Network containing transformations to execute.
+    warehouse : WarehouseBaseClass
+        Warehouse used to persist protocol units as tasks while the graph is
+        constructed.
+
+    Returns
+    -------
+    nx.DiGraph
+        A directed acyclic graph where each node is a task ID in the form
+        ``"<transformation key>:<protocol unit key>"`` and edges encode
+        protocol-unit dependencies.
+
+    Raises
+    ------
+    ValueError
+        Raised if the assembled task graph is not acyclic.
+    """
+
+    global_dag = nx.DiGraph()
+    for transformation in alchemical_network.edges:
+        dag = transformation.create()
+        for unit in dag.protocol_units:
+            node_id = f"{str(transformation.key)}:{str(unit.key)}"
+            global_dag.add_node(node_id)
+            warehouse.store_task(unit)
+        for u, v in dag.graph.edges:
+            u_id = f"{str(transformation.key)}:{str(u.key)}"
+            v_id = f"{str(transformation.key)}:{str(v.key)}"
+            global_dag.add_edge(u_id, v_id)
+
+    if not nx.is_directed_acyclic_graph(global_dag):
+        raise ValueError("AlchemicalNetwork produced a task graph that is not a DAG.")
+
+    return global_dag
+
+
+def build_task_db_from_alchemical_network(
+    alchemical_network: AlchemicalNetwork,
+    warehouse: WarehouseBaseClass,
+    db_path: Path | None = None,
+    max_tries: int = 1,
+) -> exorcist.TaskStatusDB:
+    """Create and populate a task database from an alchemical network.
+
+    Parameters
+    ----------
+    alchemical_network : AlchemicalNetwork
+        Network containing transformations to convert into task records.
+    warehouse : WarehouseBaseClass
+        Warehouse used to persist protocol units while building the task DAG.
+    db_path : pathlib.Path or None, optional
+        Location of the SQLite-backed Exorcist database. If ``None``, defaults
+        to ``Path("tasks.db")`` in the current working directory.
+    max_tries : int, default=1
+        Maximum number of retries for each task before Exorcist marks it as
+        ``TOO_MANY_RETRIES``.
+
+    Returns
+    -------
+    exorcist.TaskStatusDB
+        Initialized task database populated with graph nodes and dependency
+        edges derived from ``alchemical_network``.
+    """
+    if db_path is None:
+        db_path = Path("tasks.db")
+
+    global_dag = alchemical_network_to_task_graph(alchemical_network, warehouse)
+    db = exorcist.TaskStatusDB.from_filename(db_path)
+    db.add_task_network(global_dag, max_tries)
+    return db
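A sketch of the intended setup path through the API documented above. Here `network` is an assumed, already-built `gufe.AlchemicalNetwork` (for example, the benzene star-map fixture defined later in this diff):

```python
# Sketch only: `network` is an assumed gufe.AlchemicalNetwork.
from pathlib import Path

from openfe.orchestration.exorcist_utils import build_task_db_from_alchemical_network
from openfe.storage.warehouse import FileSystemWarehouse

root = Path("warehouse")
warehouse = FileSystemWarehouse(str(root))

# Stores every protocol unit under the warehouse's "tasks" store and records
# the node/dependency structure in an Exorcist SQLite database.
db = build_task_db_from_alchemical_network(
    network,
    warehouse,
    db_path=root / "tasks.db",  # co-locate the DB with the warehouse
    max_tries=3,                # allow retries before TOO_MANY_RETRIES
)
```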
diff --git a/src/openfe/storage/warehouse.py b/src/openfe/storage/warehouse.py
index c10baa12c..843dfc9f0 100644
--- a/src/openfe/storage/warehouse.py
+++ b/src/openfe/storage/warehouse.py
@@ -5,6 +5,8 @@
 import re
 from typing import Literal, TypedDict

+from gufe.protocols.protocoldag import ProtocolDAG
+from gufe.protocols.protocolunit import ProtocolUnit
 from gufe.storage.externalresource import ExternalStorage, FileStorage
 from gufe.tokenization import (
     JSON_HANDLER,
@@ -35,6 +37,8 @@ class WarehouseStores(TypedDict):

     setup: ExternalStorage
     result: ExternalStorage
+    shared: ExternalStorage
+    tasks: ExternalStorage


 class WarehouseBaseClass:
@@ -83,6 +87,15 @@ def delete(self, store_name: Literal["setup", "result"], location: str):
         store: ExternalStorage = self.stores[store_name]
         store.delete(location)

+    def store_task(self, obj: ProtocolUnit):
+        self._store_gufe_tokenizable("tasks", obj)
+
+    def load_task(self, obj: GufeKey) -> ProtocolUnit:
+        unit = self._load_gufe_tokenizable(obj)
+        if not isinstance(unit, ProtocolUnit):
+            raise ValueError("Unable to load ProtocolUnit")
+        return unit
+
     def store_setup_tokenizable(self, obj: GufeTokenizable):
         """Store a GufeTokenizable object in the setup store.
@@ -134,7 +147,7 @@ def load_result_tokenizable(self, obj: GufeKey) -> GufeTokenizable:
         return self._load_gufe_tokenizable(gufe_key=obj)

     def exists(self, key: GufeKey) -> bool:
-        """Check if an object with the given key exists in any store.
+        """Check if an object with the given key exists in any store that holds tokenizables.

         Parameters
         ----------
@@ -171,7 +184,12 @@ def _get_store_for_key(self, key: GufeKey) -> ExternalStorage:
                 return self.stores[name]
         raise ValueError(f"GufeKey {key} is not stored")

-    def _store_gufe_tokenizable(self, store_name: Literal["setup", "result"], obj: GufeTokenizable):
+    def _store_gufe_tokenizable(
+        self,
+        store_name: Literal["setup", "result", "tasks"],
+        obj: GufeTokenizable,
+        name: str | None = None,
+    ):
         """Store a GufeTokenizable object with deduplication.

         Parameters
@@ -197,7 +215,10 @@ def _store_gufe_tokenizable(self, store_name: Literal["setup", "result"], obj: G
         data = json.dumps(keyed_dict, cls=JSON_HANDLER.encoder, sort_keys=True).encode(
             "utf-8"
         )
-        target.store_bytes(gufe_key, data)
+        if name:
+            target.store_bytes(name, data)
+        else:
+            target.store_bytes(gufe_key, data)

     def _load_gufe_tokenizable(self, gufe_key: GufeKey) -> GufeTokenizable:
         """Load a deduplicated object from a GufeKey.
@@ -293,6 +314,17 @@ def result_store(self):
         """
         return self.stores["result"]

+    @property
+    def shared_store(self):
+        """Get the shared store.
+
+        Returns
+        -------
+        ExternalStorage
+            The shared storage location
+        """
+        return self.stores["shared"]
+

 class FileSystemWarehouse(WarehouseBaseClass):
     """Warehouse implementation using local filesystem storage.
@@ -315,5 +347,9 @@ class FileSystemWarehouse(WarehouseBaseClass):
     def __init__(self, root_dir: str = "warehouse"):
         setup_store = FileStorage(f"{root_dir}/setup")
         result_store = FileStorage(f"{root_dir}/result")
-        stores = WarehouseStores(setup=setup_store, result=result_store)
+        shared_store = FileStorage(f"{root_dir}/shared")
+        tasks_store = FileStorage(f"{root_dir}/tasks")
+        stores = WarehouseStores(
+            setup=setup_store, result=result_store, shared=shared_store, tasks=tasks_store
+        )
         super().__init__(stores)
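A sketch of the store/load round-trip the new `tasks` store enables. Here `transformation` is an assumed `gufe.Transformation`, such as the `absolute_transformation` fixture used by the tests below:

```python
# Sketch only: `transformation` is an assumed gufe.Transformation.
from openfe.storage.warehouse import FileSystemWarehouse

warehouse = FileSystemWarehouse("warehouse")  # setup/result/shared/tasks stores

unit = next(iter(transformation.create().protocol_units))

warehouse.store_task(unit)                # serialized into the "tasks" store
assert warehouse.exists(unit.key)         # key is discoverable across stores
reloaded = warehouse.load_task(unit.key)  # ValueError if not a ProtocolUnit
assert reloaded == unit
```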
diff --git a/src/openfe/tests/orchestration/__init__.py b/src/openfe/tests/orchestration/__init__.py
new file mode 100644
index 000000000..efae32ddb
--- /dev/null
+++ b/src/openfe/tests/orchestration/__init__.py
@@ -0,0 +1,2 @@
+# This code is part of OpenFE and is licensed under the MIT license.
+# For details, see https://github.com/OpenFreeEnergy/openfe
diff --git a/src/openfe/tests/orchestration/conftest.py b/src/openfe/tests/orchestration/conftest.py
new file mode 100644
index 000000000..1851b7c05
--- /dev/null
+++ b/src/openfe/tests/orchestration/conftest.py
@@ -0,0 +1,118 @@
+import gufe
+import pytest
+from gufe import ChemicalSystem, SolventComponent
+from gufe.tests.test_protocol import DummyProtocol
+from openff.units import unit
+
+
+@pytest.fixture
+def solv_comp():
+    yield SolventComponent(positive_ion="K", negative_ion="Cl", ion_concentration=0.0 * unit.molar)
+
+
+@pytest.fixture
+def solvated_complex(T4_protein_component, benzene_transforms, solv_comp):
+    return ChemicalSystem(
+        {
+            "ligand": benzene_transforms["toluene"],
+            "protein": T4_protein_component,
+            "solvent": solv_comp,
+        }
+    )
+
+
+@pytest.fixture
+def solvated_ligand(benzene_transforms, solv_comp):
+    return ChemicalSystem(
+        {
+            "ligand": benzene_transforms["toluene"],
+            "solvent": solv_comp,
+        }
+    )
+
+
+@pytest.fixture
+def absolute_transformation(solvated_ligand, solvated_complex):
+    return gufe.Transformation(
+        solvated_ligand,
+        solvated_complex,
+        protocol=DummyProtocol(settings=DummyProtocol.default_settings()),
+        mapping=None,
+    )
+
+
+@pytest.fixture
+def complex_equilibrium(solvated_complex):
+    return gufe.NonTransformation(
+        solvated_complex,
+        protocol=DummyProtocol(settings=DummyProtocol.default_settings()),
+    )
+
+
+@pytest.fixture
+def benzene_variants_star_map(benzene_transforms, solv_comp, T4_protein_component):
+    variants = ["toluene", "phenol", "benzonitrile", "anisole", "benzaldehyde", "styrene"]
+
+    # define the solvent chemical systems and transformations between
+    # benzene and the others
+    solvated_ligands = {}
+    solvated_ligand_transformations = {}
+
+    solvated_ligands["benzene"] = ChemicalSystem(
+        {
+            "solvent": solv_comp,
+            "ligand": benzene_transforms["benzene"],
+        },
+        name="benzene-solvent",
+    )
+
+    for ligand in variants:
+        solvated_ligands[ligand] = ChemicalSystem(
+            {
+                "solvent": solv_comp,
+                "ligand": benzene_transforms[ligand],
+            },
+            name=f"{ligand}-solvent",
+        )
+
+        solvated_ligand_transformations[("benzene", ligand)] = gufe.Transformation(
+            solvated_ligands["benzene"],
+            solvated_ligands[ligand],
+            protocol=DummyProtocol(settings=DummyProtocol.default_settings()),
+            mapping=None,
+        )
+
+    # define the complex chemical systems and transformations between
+    # benzene and the others
+    solvated_complexes = {}
+    solvated_complex_transformations = {}
+
+    solvated_complexes["benzene"] = gufe.ChemicalSystem(
+        {
+            "protein": T4_protein_component,
+            "solvent": solv_comp,
+            "ligand": benzene_transforms["benzene"],
+        },
+        name="benzene-complex",
+    )
+
+    for ligand in variants:
+        solvated_complexes[ligand] = gufe.ChemicalSystem(
+            {
+                "protein": T4_protein_component,
+                "solvent": solv_comp,
+                "ligand": benzene_transforms[ligand],
+            },
+            name=f"{ligand}-complex",
+        )
+        solvated_complex_transformations[("benzene", ligand)] = gufe.Transformation(
+            solvated_complexes["benzene"],
+            solvated_complexes[ligand],
+            protocol=DummyProtocol(settings=DummyProtocol.default_settings()),
+            mapping=None,
+        )
+
+    return gufe.AlchemicalNetwork(
+        list(solvated_ligand_transformations.values())
+        + list(solvated_complex_transformations.values())
+    )
diff --git a/src/openfe/tests/orchestration/test_exorcist_utils.py b/src/openfe/tests/orchestration/test_exorcist_utils.py
new file mode 100644
index 000000000..3ae9a85ad
--- /dev/null
+++ b/src/openfe/tests/orchestration/test_exorcist_utils.py
@@ -0,0 +1,210 @@
+from pathlib import Path
+from unittest import mock
+
+import exorcist
+import networkx as nx
+import pytest
+import sqlalchemy as sqla
+from gufe.tokenization import GufeKey
+
+from openfe.orchestration.exorcist_utils import (
+    alchemical_network_to_task_graph,
+    build_task_db_from_alchemical_network,
+)
+from openfe.storage.warehouse import FileSystemWarehouse
+
+
+class _RecordingWarehouse:
+    def __init__(self):
+        self.stored_tasks = []
+
+    def store_task(self, task):
+        self.stored_tasks.append(task)
+
+
+def _network_units(benzene_variants_star_map):
+    units = []
+    for transformation in benzene_variants_star_map.edges:
+        units.extend(transformation.create().protocol_units)
+    return units
+
+
+@pytest.mark.parametrize("fixture", ["benzene_variants_star_map"])
+def test_alchemical_network_to_task_graph_stores_all_units(request, fixture):
+    warehouse = _RecordingWarehouse()
+    network = request.getfixturevalue(fixture)
+    expected_units = _network_units(network)
+
+    alchemical_network_to_task_graph(network, warehouse)
+
+    stored_unit_names = [str(unit.name) for unit in warehouse.stored_tasks]
+    expected_unit_names = [str(unit.name) for unit in expected_units]
+
+    assert len(stored_unit_names) == len(expected_unit_names)
+    assert sorted(stored_unit_names) == sorted(expected_unit_names)
+
+
+@pytest.mark.parametrize("fixture", ["benzene_variants_star_map"])
+def test_alchemical_network_to_task_graph_uses_canonical_task_ids(request, fixture):
+    warehouse = _RecordingWarehouse()
+    network = request.getfixturevalue(fixture)
+
+    graph = alchemical_network_to_task_graph(network, warehouse)
+
+    transformation_keys = {str(transformation.key) for transformation in network.edges}
+    expected_protocol_unit_keys = sorted(str(unit.key) for unit in warehouse.stored_tasks)
+    observed_protocol_unit_keys = []
+
+    for node in graph.nodes:
+        transformation_key, protocol_unit_key = node.split(":", maxsplit=1)
+        assert transformation_key in transformation_keys
+        observed_protocol_unit_keys.append(protocol_unit_key)
+
+    assert sorted(observed_protocol_unit_keys) == expected_protocol_unit_keys
+
+
+@pytest.mark.parametrize("fixture", ["benzene_variants_star_map"])
+def test_alchemical_network_to_task_graph_edges_reference_existing_nodes(request, fixture):
+    warehouse = _RecordingWarehouse()
+    network = request.getfixturevalue(fixture)
+
+    graph = alchemical_network_to_task_graph(network, warehouse)
+
+    assert len(graph.edges) > 0
+    for u, v in graph.edges:
+        assert u in graph.nodes
+        assert v in graph.nodes
+
+
+def test_alchemical_network_to_task_graph_raises_for_cycle():
+    class _Unit:
+        def __init__(self, name: str, key: str):
+            self.name = name
+            self.key = key
+
+    class _Transformation:
+        name = "cyclic"
+        key = "Transformation-cycle"
+
+        def create(self):
+            unit_a = _Unit("unit-a", "ProtocolUnit-a")
+            unit_b = _Unit("unit-b", "ProtocolUnit-b")
+            dag = mock.Mock()
+            dag.protocol_units = [unit_a, unit_b]
+            dag.graph = nx.DiGraph()
+            dag.graph.add_nodes_from([unit_a, unit_b])
+            dag.graph.add_edges_from([(unit_a, unit_b), (unit_b, unit_a)])
+            return dag
+
+    network = mock.Mock()
+    network.edges = [_Transformation()]
+    warehouse = mock.Mock()
+
+    with pytest.raises(ValueError, match="not a DAG"):
+        alchemical_network_to_task_graph(network, warehouse)
+
+
+@pytest.mark.parametrize("fixture", ["benzene_variants_star_map"])
+def test_build_task_db_checkout_order_is_dependency_safe(tmp_path, request, fixture):
+    network = request.getfixturevalue(fixture)
+    warehouse = FileSystemWarehouse(str(tmp_path / "warehouse"))
+    # Build the real sqlite task DB from a real alchemical network fixture.
+    db = build_task_db_from_alchemical_network(
+        network,
+        warehouse,
+        db_path=tmp_path / "tasks.db",
+    )
+
+    # Read task IDs and dependency edges from the persisted DB state.
+    initial_task_rows = list(db.get_all_tasks())
+    graph_taskids = {row.taskid for row in initial_task_rows}
+    with db.engine.connect() as conn:
+        dep_rows = conn.execute(sqla.select(db.dependencies_table)).all()
+    graph_edges = {(row._mapping["from"], row._mapping["to"]) for row in dep_rows}
+
+    checkout_order = []
+    # Hard upper bound prevents infinite checkout loops.
+    max_checkouts = len(graph_taskids)
+    print(f"Max Checkout={max_checkouts}")
+    for _ in range(max_checkouts):
+        taskid = db.check_out_task()
+        if taskid is None:
+            break
+
+        checkout_order.append(taskid)
+        _, protocol_unit_key = taskid.split(":", maxsplit=1)
+        loaded_unit = warehouse.load_task(GufeKey(protocol_unit_key))
+        assert str(loaded_unit.key) == protocol_unit_key
+        db.mark_task_completed(taskid, success=True)
+
+    # Coverage/completion: every task is checked out exactly once.
+    observed_taskids = set(checkout_order)
+    assert observed_taskids == graph_taskids
+    assert len(checkout_order) == len(graph_taskids)
+
+    # Dependency safety: upstream tasks must appear before downstream tasks.
+    checkout_index = {taskid: idx for idx, taskid in enumerate(checkout_order)}
+    for upstream, downstream in graph_edges:
+        assert checkout_index[upstream] < checkout_index[downstream]
+
+    # Final DB state: all tasks are completed.
+    task_rows = list(db.get_all_tasks())
+    assert len(task_rows) == len(graph_taskids)
+    assert {row.taskid for row in task_rows} == graph_taskids
+    assert {row.status for row in task_rows} == {exorcist.TaskStatus.COMPLETED.value}
+
+
+@pytest.mark.parametrize("fixture", ["benzene_variants_star_map"])
+def test_build_task_db_default_path(request, fixture):
+    network = request.getfixturevalue(fixture)
+    warehouse = mock.Mock()
+    fake_graph = nx.DiGraph()
+    fake_db = mock.Mock()
+
+    with (
+        mock.patch(
+            "openfe.orchestration.exorcist_utils.alchemical_network_to_task_graph",
+            return_value=fake_graph,
+        ) as task_graph_mock,
+        mock.patch(
+            "openfe.orchestration.exorcist_utils.exorcist.TaskStatusDB.from_filename",
+            return_value=fake_db,
+        ) as db_ctor,
+    ):
+        result = build_task_db_from_alchemical_network(network, warehouse)
+
+    task_graph_mock.assert_called_once_with(network, warehouse)
+    db_ctor.assert_called_once_with(Path("tasks.db"))
+    fake_db.add_task_network.assert_called_once_with(fake_graph, 1)
+    assert result is fake_db
+
+
+@pytest.mark.parametrize("fixture", ["benzene_variants_star_map"])
+def test_build_task_db_forwards_graph_and_max_tries(request, tmp_path, fixture):
+    network = request.getfixturevalue(fixture)
+    warehouse = mock.Mock()
+    fake_graph = nx.DiGraph()
+    fake_db = mock.Mock()
+    db_path = tmp_path / "custom_tasks.db"
+
+    with (
+        mock.patch(
+            "openfe.orchestration.exorcist_utils.alchemical_network_to_task_graph",
+            return_value=fake_graph,
+        ) as task_graph_mock,
+        mock.patch(
+            "openfe.orchestration.exorcist_utils.exorcist.TaskStatusDB.from_filename",
+            return_value=fake_db,
+        ) as db_ctor,
+    ):
+        result = build_task_db_from_alchemical_network(
+            network,
+            warehouse,
+            db_path=db_path,
+            max_tries=7,
+        )
+
+    task_graph_mock.assert_called_once_with(network, warehouse)
+    db_ctor.assert_called_once_with(db_path)
+    fake_db.add_task_network.assert_called_once_with(fake_graph, 7)
+    assert result is fake_db
diff --git a/src/openfe/tests/orchestration/test_worker.py b/src/openfe/tests/orchestration/test_worker.py
new file mode 100644
index 000000000..dca306d26
--- /dev/null
+++ b/src/openfe/tests/orchestration/test_worker.py
@@ -0,0 +1,154 @@
+from pathlib import Path
+from unittest import mock
+
+import exorcist
+import gufe
+import networkx as nx
+import pytest
+from gufe.protocols.protocolunit import ProtocolUnit
+
+from openfe.orchestration import Worker
+from openfe.orchestration.exorcist_utils import build_task_db_from_alchemical_network
+from openfe.storage.warehouse import FileSystemWarehouse
+
+
+def _result_store_files(warehouse: FileSystemWarehouse) -> set[str]:
+    result_root = Path(warehouse.result_store.root_dir)
+    return {str(path.relative_to(result_root)) for path in result_root.rglob("*") if path.is_file()}
+
+
+def _contains_protocol_unit(value) -> bool:
+    if isinstance(value, ProtocolUnit):
+        return True
+    if isinstance(value, dict):
+        return any(_contains_protocol_unit(item) for item in value.values())
+    if isinstance(value, list):
+        return any(_contains_protocol_unit(item) for item in value)
+    return False
+
+
+def _get_dependency_free_unit(absolute_transformation):
+    for unit in absolute_transformation.create().protocol_units:
+        if not _contains_protocol_unit(unit.inputs):
+            return unit
+    raise ValueError("No dependency-free protocol unit found for execution test setup.")
+
+
+@pytest.fixture
+def worker_with_real_db(tmp_path, absolute_transformation):
+    warehouse_root = tmp_path / "warehouse"
+    db_path = warehouse_root / "tasks.db"
+    warehouse = FileSystemWarehouse(str(warehouse_root))
+    network = gufe.AlchemicalNetwork([absolute_transformation])
+    db = build_task_db_from_alchemical_network(network, warehouse, db_path=db_path)
+    worker = Worker(warehouse=warehouse, task_db_path=db_path)
+    return worker, warehouse, db
+
+
+@pytest.fixture
+def worker_with_executable_task_db(tmp_path, absolute_transformation):
+    warehouse_root = tmp_path / "warehouse"
+    db_path = warehouse_root / "tasks.db"
+    warehouse = FileSystemWarehouse(str(warehouse_root))
+    unit = _get_dependency_free_unit(absolute_transformation)
+    warehouse.store_task(unit)
+
+    taskid = f"{absolute_transformation.key}:{unit.key}"
+    task_graph = nx.DiGraph()
+    task_graph.add_node(taskid)
+
+    db = exorcist.TaskStatusDB.from_filename(db_path)
+    db.add_task_network(task_graph, 1)
+
+    worker = Worker(warehouse=warehouse, task_db_path=db_path)
+    return worker, warehouse, db, unit
+
+
+def test_get_task_uses_default_db_path_without_patching(
+    tmp_path, monkeypatch, absolute_transformation
+):
+    monkeypatch.chdir(tmp_path)
+    warehouse = FileSystemWarehouse("warehouse")
+    db_path = Path("warehouse/tasks.db")
+    network = gufe.AlchemicalNetwork([absolute_transformation])
+    db = build_task_db_from_alchemical_network(network, warehouse, db_path=db_path)
+
+    worker = Worker(warehouse=warehouse)
+    loaded = worker._get_task()
+
+    expected_keys = {task_row.taskid.split(":", maxsplit=1)[1] for task_row in db.get_all_tasks()}
+    assert worker.task_db_path == Path("./warehouse/tasks.db")
+    assert str(loaded.key) in expected_keys
+
+
+def test_get_task_returns_task_with_canonical_protocol_unit_suffix(worker_with_real_db):
+    worker, warehouse, db = worker_with_real_db
+
+    task_ids = [row.taskid for row in db.get_all_tasks()]
+    expected_protocol_unit_keys = {task_id.split(":", maxsplit=1)[1] for task_id in task_ids}
+
+    loaded = worker._get_task()
+    reloaded = warehouse.load_task(loaded.key)
+
+    assert str(loaded.key) in expected_protocol_unit_keys
+    assert loaded == reloaded
+
+
+def test_execute_unit_stores_real_result(worker_with_executable_task_db, tmp_path):
+    worker, warehouse, db, _ = worker_with_executable_task_db
+    before = _result_store_files(warehouse)
+
+    execution = worker.execute_unit(scratch=tmp_path / "scratch")
+    assert execution is not None
+    taskid, _ = execution
+
+    after = _result_store_files(warehouse)
+    assert len(after) > len(before)
+    rows = list(db.get_all_tasks())
+    status_by_taskid = {row.taskid: row.status for row in rows}
+    assert status_by_taskid[taskid] == exorcist.TaskStatus.COMPLETED.value
+
+
+def test_execute_unit_propagates_execute_error_without_store(
+    worker_with_executable_task_db, tmp_path
+):
+    worker, warehouse, db, unit = worker_with_executable_task_db
+    before = _result_store_files(warehouse)
+    taskid = list(db.get_all_tasks())[0].taskid
+
+    with mock.patch.object(
+        type(unit),
+        "execute",
+        autospec=True,
+        side_effect=RuntimeError("unit execution failed"),
+    ):
+        with pytest.raises(RuntimeError, match="unit execution failed"):
+            worker.execute_unit(scratch=tmp_path / "scratch")
+
+    after = _result_store_files(warehouse)
+    assert after == before
+    rows = list(db.get_all_tasks())
+    status_by_taskid = {row.taskid: row.status for row in rows}
+    assert status_by_taskid[taskid] == exorcist.TaskStatus.TOO_MANY_RETRIES.value
+
+
+def test_checkout_task_returns_none_when_no_available_tasks(tmp_path):
+    warehouse_root = tmp_path / "warehouse"
+    db_path = warehouse_root / "tasks.db"
+    warehouse_root.mkdir(parents=True, exist_ok=True)
+    warehouse = FileSystemWarehouse(str(warehouse_root))
+    exorcist.TaskStatusDB.from_filename(db_path)
+    worker = Worker(warehouse=warehouse, task_db_path=db_path)
+
+    assert worker._checkout_task() is None
+
+
+def test_execute_unit_returns_none_when_no_available_tasks(tmp_path):
+    warehouse_root = tmp_path / "warehouse"
+    db_path = warehouse_root / "tasks.db"
+    warehouse_root.mkdir(parents=True, exist_ok=True)
+    warehouse = FileSystemWarehouse(str(warehouse_root))
+    exorcist.TaskStatusDB.from_filename(db_path)
+    worker = Worker(warehouse=warehouse, task_db_path=db_path)
+
+    assert worker.execute_unit(scratch=tmp_path / "scratch") is None
diff --git a/src/openfe/tests/storage/test_warehouse.py b/src/openfe/tests/storage/test_warehouse.py
index 572769d70..d113cf1d7 100644
--- a/src/openfe/tests/storage/test_warehouse.py
+++ b/src/openfe/tests/storage/test_warehouse.py
@@ -19,18 +19,35 @@ class TestWarehouseBaseClass:
     def test_store_protocol_dag_result(self):
         pytest.skip("Not implemented yet")

+    @staticmethod
+    def _build_stores() -> WarehouseStores:
+        return WarehouseStores(
+            setup=MemoryStorage(),
+            result=MemoryStorage(),
+            shared=MemoryStorage(),
+            tasks=MemoryStorage(),
+        )
+
+    @staticmethod
+    def _get_protocol_unit(transformation):
+        dag = transformation.create()
+        return next(iter(dag.protocol_units))
+
     @staticmethod
     def _test_store_load_same_process(
-        obj, store_func_name, load_func_name, store_name: Literal["setup", "result"]
+        obj,
+        store_func_name,
+        load_func_name,
+        store_name: Literal["setup", "result", "tasks"],
     ):
-        setup_store = MemoryStorage()
-        result_store = MemoryStorage()
-        stores = WarehouseStores(setup=setup_store, result=result_store)
+        stores = TestWarehouseBaseClass._build_stores()
         client = WarehouseBaseClass(stores)
         store_func = getattr(client, store_func_name)
         load_func = getattr(client, load_func_name)
-        assert setup_store._data == {}
-        assert result_store._data == {}
+        assert stores["setup"]._data == {}
+        assert stores["result"]._data == {}
+        assert stores["shared"]._data == {}
+        assert stores["tasks"]._data == {}
         store_func(obj)
         store_under_test: MemoryStorage = stores[store_name]
         assert store_under_test._data != {}
@@ -43,16 +60,16 @@ def _test_store_load_different_process(
         obj: GufeTokenizable,
         store_func_name,
         load_func_name,
-        store_name: Literal["setup", "result"],
+        store_name: Literal["setup", "result", "tasks"],
     ):
-        setup_store = MemoryStorage()
-        result_store = MemoryStorage()
-        stores = WarehouseStores(setup=setup_store, result=result_store)
+        stores = TestWarehouseBaseClass._build_stores()
         client = WarehouseBaseClass(stores)
         store_func = getattr(client, store_func_name)
         load_func = getattr(client, load_func_name)
-        assert setup_store._data == {}
-        assert result_store._data == {}
+        assert stores["setup"]._data == {}
+        assert stores["result"]._data == {}
+        assert stores["shared"]._data == {}
+        assert stores["tasks"]._data == {}
         store_func(obj)
         store_under_test: MemoryStorage = stores[store_name]
         assert store_under_test._data != {}
@@ -65,6 +82,45 @@ def _test_store_load_different_process(
         assert reload == obj
         assert reload is not obj

+    def test_store_load_task_same_process(self, absolute_transformation):
+        unit = self._get_protocol_unit(absolute_transformation)
+        self._test_store_load_same_process(unit, "store_task", "load_task", "tasks")
+
+    def test_store_load_task_different_process(self, absolute_transformation):
+        unit = self._get_protocol_unit(absolute_transformation)
+        self._test_store_load_different_process(unit, "store_task", "load_task", "tasks")
+
+    def test_store_task_writes_to_tasks_store(self, absolute_transformation):
+        unit = self._get_protocol_unit(absolute_transformation)
+        stores = self._build_stores()
+        client = WarehouseBaseClass(stores)
+        client.store_task(unit)
+
+        assert stores["tasks"]._data != {}
+        assert stores["setup"]._data == {}
+        assert stores["result"]._data == {}
+        assert stores["shared"]._data == {}
+
+    def test_exists_finds_task_key(self, absolute_transformation):
+        unit = self._get_protocol_unit(absolute_transformation)
+        stores = self._build_stores()
+        client = WarehouseBaseClass(stores)
+
+        client.store_task(unit)
+
+        assert client.exists(unit.key)
+
+    def test_load_task_returns_object(self, absolute_transformation):
+        unit = self._get_protocol_unit(absolute_transformation)
+        stores = self._build_stores()
+        client = WarehouseBaseClass(stores)
+
+        client.store_task(unit)
+        loaded = client.load_task(unit.key)
+
+        assert loaded is not None
+        assert isinstance(loaded, GufeTokenizable)
+
     @pytest.mark.parametrize(
         "fixture",
         ["absolute_transformation", "complex_equilibrium"],
@@ -164,6 +220,22 @@ def test_store_load_transformation_same_process(self, request, fixture):
             "load_setup_tokenizable",
         )

+    def test_filesystemwarehouse_has_shared_and_tasks_stores(self, absolute_transformation):
+        unit = TestWarehouseBaseClass._get_protocol_unit(absolute_transformation)
+
+        with tempfile.TemporaryDirectory() as tmpdir:
+            client = FileSystemWarehouse(tmpdir)
+
+            assert "shared" in client.stores
+            assert "tasks" in client.stores
+
+            client.stores["shared"].store_bytes("sentinel", b"shared-data")
+            with client.stores["shared"].load_stream("sentinel") as f:
+                assert f.read() == b"shared-data"
+
+            client.store_task(unit)
+            assert client.exists(unit.key)
+
     @pytest.mark.parametrize(
         "fixture",
         ["absolute_transformation", "complex_equilibrium"],
diff --git a/src/openfecli/commands/worker.py b/src/openfecli/commands/worker.py
new file mode 100644
index 000000000..dbc52b6b3
--- /dev/null
+++ b/src/openfecli/commands/worker.py
@@ -0,0 +1,83 @@
+# This code is part of OpenFE and is licensed under the MIT license.
+# For details, see https://github.com/OpenFreeEnergy/openfe
+
+import pathlib
+
+import click
+
+from openfecli import OFECommandPlugin
+from openfecli.utils import print_duration, write
+
+
+def _build_worker(warehouse_path: pathlib.Path, db_path: pathlib.Path):
+    from openfe.orchestration import Worker
+    from openfe.storage.warehouse import FileSystemWarehouse
+
+    warehouse = FileSystemWarehouse(str(warehouse_path))
+    return Worker(warehouse=warehouse, task_db_path=db_path)
+
+
+def worker_main(warehouse_path: pathlib.Path, scratch: pathlib.Path | None):
+    db_path = warehouse_path / "tasks.db"
+    if not db_path.is_file():
+        raise click.ClickException(f"Task database not found at: {db_path}")
+
+    if scratch is None:
+        scratch = pathlib.Path.cwd()
+
+    scratch.mkdir(parents=True, exist_ok=True)
+
+    worker = _build_worker(warehouse_path, db_path)
+
+    try:
+        execution = worker.execute_unit(scratch=scratch)
+    except Exception as exc:
+        raise click.ClickException(f"Task execution failed: {exc}") from exc
+
+    if execution is None:
+        write("No available task in task graph.")
+        return None
+
+    taskid, result = execution
+    if not result.ok():
+        raise click.ClickException(f"Task '{taskid}' returned a failure result.")
+
+    write(f"Completed task: {taskid}")
+    return result
+
+
+@click.command("worker", short_help="Execute one available task from a filesystem warehouse")
+@click.argument(
+    "warehouse_path",
+    type=click.Path(
+        exists=True,
+        readable=True,
+        file_okay=False,
+        dir_okay=True,
+        path_type=pathlib.Path,
+    ),
+)
+@click.option(
+    "--scratch",
+    "-s",
+    default=None,
+    type=click.Path(
+        writable=True,
+        file_okay=False,
+        dir_okay=True,
+        path_type=pathlib.Path,
+    ),
+    help="Directory for scratch files. Defaults to current working directory.",
+)
+@print_duration
+def worker(warehouse_path: pathlib.Path, scratch: pathlib.Path | None):
+    """
+    Execute one available task from a warehouse task graph.
+
+    The warehouse directory must contain a ``tasks.db`` task database and task
+    payloads under ``tasks/`` created via OpenFE orchestration setup.
+    """
+    worker_main(warehouse_path=warehouse_path, scratch=scratch)
+
+
+PLUGIN = OFECommandPlugin(command=worker, section="Quickrun Executor", requires_ofe=(0, 3))
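The command can also be exercised in-process with Click's test runner, mirroring the tests that follow. The shell form would presumably be `openfe worker ./warehouse -s ./scratch`, assuming the plugin is picked up by the `openfe` CLI entry point; the `warehouse` directory with a `tasks.db` is an assumption of this sketch:

```python
# Sketch only: assumes ./warehouse contains a tasks.db plus stored task payloads.
from click.testing import CliRunner

from openfecli.commands.worker import worker

runner = CliRunner()
result = runner.invoke(worker, ["warehouse", "--scratch", "scratch"])
print(result.exit_code)  # 0 on success or when no task is available
print(result.output)     # "Completed task: ..." or "No available task in task graph."
```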
diff --git a/src/openfecli/tests/commands/test_worker.py b/src/openfecli/tests/commands/test_worker.py
new file mode 100644
index 000000000..6d3b55f7e
--- /dev/null
+++ b/src/openfecli/tests/commands/test_worker.py
@@ -0,0 +1,107 @@
+from pathlib import Path
+from unittest import mock
+
+from click.testing import CliRunner
+
+from openfecli.commands.worker import worker
+
+
+class _SuccessfulResult:
+    def ok(self):
+        return True
+
+
+class _FailedResult:
+    def ok(self):
+        return False
+
+
+def test_worker_requires_task_database():
+    runner = CliRunner()
+    with runner.isolated_filesystem():
+        Path("warehouse").mkdir()
+        result = runner.invoke(worker, ["warehouse"])
+        assert result.exit_code == 1
+        assert "Task database not found at" in result.output
+
+
+def test_worker_no_available_task_exits_zero():
+    runner = CliRunner()
+    with runner.isolated_filesystem():
+        warehouse_path = Path("warehouse")
+        warehouse_path.mkdir()
+        (warehouse_path / "tasks.db").touch()
+
+        mock_worker = mock.Mock()
+        mock_worker.execute_unit.return_value = None
+
+        with mock.patch(
+            "openfecli.commands.worker._build_worker", return_value=mock_worker
+        ) as build_worker:
+            result = runner.invoke(worker, ["warehouse"])
+
+        assert result.exit_code == 0
+        assert "No available task in task graph." in result.output
+        build_worker.assert_called_once_with(warehouse_path, warehouse_path / "tasks.db")
+        kwargs = mock_worker.execute_unit.call_args.kwargs
+        assert kwargs["scratch"] == Path.cwd()
+
+
+def test_worker_executes_one_task_and_reports_completion():
+    runner = CliRunner()
+    with runner.isolated_filesystem():
+        warehouse_path = Path("warehouse")
+        warehouse_path.mkdir()
+        (warehouse_path / "tasks.db").touch()
+
+        mock_worker = mock.Mock()
+        mock_worker.execute_unit.return_value = (
+            "Transformation-abc:ProtocolUnit-def",
+            _SuccessfulResult(),
+        )
+
+        with mock.patch("openfecli.commands.worker._build_worker", return_value=mock_worker):
+            result = runner.invoke(worker, ["warehouse", "--scratch", "scratch"])
+
+        assert result.exit_code == 0
+        assert "Completed task: Transformation-abc:ProtocolUnit-def" in result.output
+        assert Path("scratch").is_dir()
+        kwargs = mock_worker.execute_unit.call_args.kwargs
+        assert kwargs["scratch"] == Path("scratch")
+
+
+def test_worker_raises_when_result_is_failure():
+    runner = CliRunner()
+    with runner.isolated_filesystem():
+        warehouse_path = Path("warehouse")
+        warehouse_path.mkdir()
+        (warehouse_path / "tasks.db").touch()
+
+        mock_worker = mock.Mock()
+        mock_worker.execute_unit.return_value = (
+            "Transformation-abc:ProtocolUnit-def",
+            _FailedResult(),
+        )
+
+        with mock.patch("openfecli.commands.worker._build_worker", return_value=mock_worker):
+            result = runner.invoke(worker, ["warehouse"])
+
+        assert result.exit_code == 1
+        assert "returned a failure result" in result.output
+
+
+def test_worker_raises_when_execution_throws():
+    runner = CliRunner()
+    with runner.isolated_filesystem():
+        warehouse_path = Path("warehouse")
+        warehouse_path.mkdir()
+        (warehouse_path / "tasks.db").touch()
+
+        mock_worker = mock.Mock()
+        mock_worker.execute_unit.side_effect = RuntimeError("boom")
+
+        with mock.patch("openfecli.commands.worker._build_worker", return_value=mock_worker):
+            result = runner.invoke(worker, ["warehouse"])
+
+        assert result.exit_code == 1
+        assert "Task execution failed: boom" in result.output
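Taken together, a hedged end-to-end sketch of the workflow this diff enables: build a network, populate the warehouse and task database, then run one or more workers. `network` is an assumed, populated `gufe.AlchemicalNetwork`; because checkout goes through a shared SQLite database, the dependency-safety test above suggests several workers (e.g. separate HPC jobs) can drain the same queue:

```python
# Sketch only: `network` is an assumed gufe.AlchemicalNetwork.
from pathlib import Path

from openfe.orchestration import Worker
from openfe.orchestration.exorcist_utils import build_task_db_from_alchemical_network
from openfe.storage.warehouse import FileSystemWarehouse

root = Path("warehouse")
warehouse = FileSystemWarehouse(str(root))

# Translate the network into stored protocol units plus an Exorcist task DB.
build_task_db_from_alchemical_network(network, warehouse, db_path=root / "tasks.db")

# Any number of worker processes can now poll the DB; Exorcist hands out only
# tasks whose dependencies have completed.
scratch = Path("scratch")
scratch.mkdir(exist_ok=True)
worker = Worker(warehouse=warehouse, task_db_path=root / "tasks.db")
while worker.execute_unit(scratch=scratch) is not None:
    pass  # results accumulate in the warehouse's "result" store
```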