Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions environment.yml
Original file line number Diff line number Diff line change
Expand Up @@ -53,3 +53,4 @@ dependencies:
- pip:
- git+https://github.com/OpenFreeEnergy/gufe@main
- git+https://github.com/fatiando/pooch@main # related to https://github.com/fatiando/pooch/issues/502
- git+https://github.com/OpenFreeEnergy/exorcist@main
126 changes: 126 additions & 0 deletions src/openfe/orchestration/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,126 @@
"""Task orchestration utilities backed by Exorcist and a warehouse."""

from dataclasses import dataclass
from pathlib import Path

from exorcist.taskdb import TaskStatusDB
from gufe.protocols.protocoldag import _pu_to_pur
from gufe.protocols.protocolunit import (
Context,
ProtocolUnit,
ProtocolUnitResult,
)
from gufe.storage.externalresource.filestorage import FileStorage
from gufe.tokenization import GufeKey

from openfe.storage.warehouse import FileSystemWarehouse

from .exorcist_utils import (
alchemical_network_to_task_graph,
build_task_db_from_alchemical_network,
)


@dataclass
class Worker:
    """Execute protocol units from an Exorcist task database.

    Parameters
    ----------
    warehouse : FileSystemWarehouse
        Warehouse used to load queued tasks and store execution results.
    task_db_path : pathlib.Path, default=Path("./warehouse/tasks.db")
        Path to the Exorcist SQLite task database.
    """

    warehouse: FileSystemWarehouse
    # A fresh TaskStatusDB handle is opened from this path on every operation
    # rather than being cached on the worker instance.
    task_db_path: Path = Path("./warehouse/tasks.db")

    def _checkout_task(self) -> tuple[str, ProtocolUnit] | None:
        """Check out one available task and load its protocol unit.

        Returns
        -------
        tuple[str, ProtocolUnit] or None
            The checked-out task ID and corresponding protocol unit, or
            ``None`` if no task is currently available.
        """

        db: TaskStatusDB = TaskStatusDB.from_filename(self.task_db_path)
        # The format for the taskid is "Transformation-<HASH>:ProtocolUnit-<HASH>"
        taskid = db.check_out_task()
        if taskid is None:
            return None

        # maxsplit=1 so any ":" inside the protocol-unit portion is preserved.
        _, protocol_unit_key = taskid.split(":", maxsplit=1)
        unit = self.warehouse.load_task(GufeKey(protocol_unit_key))
        return taskid, unit

    def _get_task(self) -> ProtocolUnit:
        """Return the next available protocol unit.

        Returns
        -------
        ProtocolUnit
            A protocol unit loaded from the warehouse.

        Raises
        ------
        RuntimeError
            Raised when no task is available in the task database.

        Notes
        -----
        NOTE(review): the checked-out task ID is discarded here, so a task
        obtained via this method is never marked completed in the task
        database — confirm this is intended.
        """

        task = self._checkout_task()
        if task is None:
            raise RuntimeError("No AVAILABLE tasks found in the task database.")
        _, unit = task
        return unit

    def execute_unit(self, scratch: Path) -> tuple[str, ProtocolUnitResult] | None:
        """Execute one checked-out protocol unit and persist its result.

        Parameters
        ----------
        scratch : pathlib.Path
            Scratch directory passed to the protocol execution context.

        Returns
        -------
        tuple[str, ProtocolUnitResult] or None
            The task ID and execution result for the processed task, or
            ``None`` if no task is currently available.

        Raises
        ------
        Exception
            Re-raises any exception thrown during protocol unit execution after
            marking the task as failed.
        """

        # 1. Get task/unit
        task = self._checkout_task()
        if task is None:
            return None
        taskid, unit = task
        # 2. Construct the context
        # NOTE: On changes to context, this can easily be replaced with external storage objects
        # However, to satisfy the current work, we will use this implementation where we
        # force the use of a FileSystemWarehouse and in turn can assert that an object is FileStorage.
        shared_store: FileStorage = self.warehouse.stores["shared"]
        shared_root_dir = shared_store.root_dir
        ctx = Context(scratch, shared=shared_root_dir)
        # NOTE(review): `results` is empty, so the unit's inputs are rewritten
        # with no upstream ProtocolUnitResults — confirm units executed here
        # never depend on results of prior units.
        results: dict[GufeKey, ProtocolUnitResult] = {}
        inputs = _pu_to_pur(unit.inputs, results)
        db: TaskStatusDB = TaskStatusDB.from_filename(self.task_db_path)
        # 3. Execute unit
        try:
            result = unit.execute(context=ctx, **inputs)
        except Exception:
            # Mark the task failed before propagating the error.
            db.mark_task_completed(taskid, success=False)
            raise

        db.mark_task_completed(taskid, success=result.ok())
        # 4. output result to warehouse
        # TODO: we may need to end up handling namespacing on the warehouse side for tokenizables
        self.warehouse.store_result_tokenizable(result)
        return taskid, result
95 changes: 95 additions & 0 deletions src/openfe/orchestration/exorcist_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,95 @@
"""Utilities for building Exorcist task graphs and task databases.

This module translates an :class:`gufe.AlchemicalNetwork` into Exorcist task
structures and can initialize an Exorcist task database from that graph.
"""

from pathlib import Path

import exorcist
import networkx as nx
from gufe import AlchemicalNetwork

from openfe.storage.warehouse import WarehouseBaseClass


def alchemical_network_to_task_graph(
    alchemical_network: AlchemicalNetwork, warehouse: WarehouseBaseClass
) -> nx.DiGraph:
    """Assemble the global task DAG for an alchemical network.

    Each transformation in the network is expanded into its protocol DAG;
    every protocol unit becomes a node (and is persisted to ``warehouse``),
    and intra-DAG dependencies become edges.

    Parameters
    ----------
    alchemical_network : AlchemicalNetwork
        Network containing transformations to execute.
    warehouse : WarehouseBaseClass
        Warehouse used to persist protocol units as tasks while the graph is
        constructed.

    Returns
    -------
    nx.DiGraph
        A directed acyclic graph where each node is a task ID in the form
        ``"<transformation_key>:<protocol_unit_key>"`` and edges encode
        protocol-unit dependencies.

    Raises
    ------
    ValueError
        Raised if the assembled task graph is not acyclic.
    """

    def _task_id(transformation_key, unit_key) -> str:
        # Pair the transformation key with the unit key so task IDs remain
        # unique across transformations.
        return f"{str(transformation_key)}:{str(unit_key)}"

    task_graph = nx.DiGraph()
    for transformation in alchemical_network.edges:
        protocol_dag = transformation.create()
        # One node per protocol unit; persist the unit while we are at it.
        for unit in protocol_dag.protocol_units:
            task_graph.add_node(_task_id(transformation.key, unit.key))
            warehouse.store_task(unit)
        # Mirror the intra-DAG dependency edges into the global graph.
        task_graph.add_edges_from(
            (
                _task_id(transformation.key, u.key),
                _task_id(transformation.key, v.key),
            )
            for u, v in protocol_dag.graph.edges
        )

    if not nx.is_directed_acyclic_graph(task_graph):
        raise ValueError("AlchemicalNetwork produced a task graph that is not a DAG.")

    return task_graph


def build_task_db_from_alchemical_network(
    alchemical_network: AlchemicalNetwork,
    warehouse: WarehouseBaseClass,
    db_path: Path | None = None,
    max_tries: int = 1,
) -> exorcist.TaskStatusDB:
    """Initialize an Exorcist task database from an alchemical network.

    Parameters
    ----------
    alchemical_network : AlchemicalNetwork
        Network containing transformations to convert into task records.
    warehouse : WarehouseBaseClass
        Warehouse used to persist protocol units while building the task DAG.
    db_path : pathlib.Path or None, optional
        Location of the SQLite-backed Exorcist database. If ``None``, defaults
        to ``Path("tasks.db")`` in the current working directory.
    max_tries : int, default=1
        Maximum number of retries for each task before Exorcist marks it as
        ``TOO_MANY_RETRIES``.

    Returns
    -------
    exorcist.TaskStatusDB
        Initialized task database populated with graph nodes and dependency
        edges derived from ``alchemical_network``.
    """
    target_path = Path("tasks.db") if db_path is None else db_path

    task_graph = alchemical_network_to_task_graph(alchemical_network, warehouse)
    task_db = exorcist.TaskStatusDB.from_filename(target_path)
    task_db.add_task_network(task_graph, max_tries)
    return task_db
44 changes: 40 additions & 4 deletions src/openfe/storage/warehouse.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
import re
from typing import Literal, TypedDict

from gufe.protocols.protocoldag import ProtocolDAG
from gufe.protocols.protocolunit import ProtocolUnit
from gufe.storage.externalresource import ExternalStorage, FileStorage
from gufe.tokenization import (
JSON_HANDLER,
Expand Down Expand Up @@ -35,6 +37,8 @@ class WarehouseStores(TypedDict):

setup: ExternalStorage
result: ExternalStorage
shared: ExternalStorage
tasks: ExternalStorage


class WarehouseBaseClass:
Expand Down Expand Up @@ -83,6 +87,15 @@ def delete(self, store_name: Literal["setup", "result"], location: str):
store: ExternalStorage = self.stores[store_name]
store.delete(location)

def store_task(self, obj: ProtocolUnit):
    """Store a ProtocolUnit in the tasks store.

    Parameters
    ----------
    obj : ProtocolUnit
        Protocol unit to persist; it is stored under its gufe key.
    """
    self._store_gufe_tokenizable("tasks", obj)

def load_task(self, obj: GufeKey) -> ProtocolUnit:
    """Load a previously stored ProtocolUnit by its gufe key.

    Parameters
    ----------
    obj : GufeKey
        Key under which the protocol unit was stored.

    Returns
    -------
    ProtocolUnit
        The deserialized protocol unit.

    Raises
    ------
    ValueError
        If no ProtocolUnit is stored under ``obj`` (e.g. the key resolves to
        a different tokenizable type).
    """
    unit = self._load_gufe_tokenizable(obj)
    if not isinstance(unit, ProtocolUnit):
        # Include the key and the actual type so the failure is debuggable.
        raise ValueError(
            f"Unable to load ProtocolUnit for key {obj!r}: "
            f"got {type(unit).__name__} instead."
        )
    return unit

def store_setup_tokenizable(self, obj: GufeTokenizable):
"""Store a GufeTokenizable object in the setup store.

Expand Down Expand Up @@ -134,7 +147,7 @@ def load_result_tokenizable(self, obj: GufeKey) -> GufeTokenizable:
return self._load_gufe_tokenizable(gufe_key=obj)

def exists(self, key: GufeKey) -> bool:
"""Check if an object with the given key exists in any store.
"""Check if an object with the given key exists in any store that holds tokenizables.

Parameters
----------
Expand Down Expand Up @@ -171,7 +184,12 @@ def _get_store_for_key(self, key: GufeKey) -> ExternalStorage:
return self.stores[name]
raise ValueError(f"GufeKey {key} is not stored")

def _store_gufe_tokenizable(self, store_name: Literal["setup", "result"], obj: GufeTokenizable):
def _store_gufe_tokenizable(
self,
store_name: Literal["setup", "result", "tasks"],
obj: GufeTokenizable,
name: str | None = None,
):
"""Store a GufeTokenizable object with deduplication.

Parameters
Expand All @@ -197,7 +215,10 @@ def _store_gufe_tokenizable(self, store_name: Literal["setup", "result"], obj: G
data = json.dumps(keyed_dict, cls=JSON_HANDLER.encoder, sort_keys=True).encode(
"utf-8"
)
target.store_bytes(gufe_key, data)
if name:
target.store_bytes(name, data)
else:
target.store_bytes(gufe_key, data)

def _load_gufe_tokenizable(self, gufe_key: GufeKey) -> GufeTokenizable:
"""Load a deduplicated object from a GufeKey.
Expand Down Expand Up @@ -293,6 +314,17 @@ def result_store(self):
"""
return self.stores["result"]

@property
def shared_store(self):
    """The shared store of this warehouse.

    Returns
    -------
    ExternalStorage
        The shared storage location
    """
    store = self.stores["shared"]
    return store


class FileSystemWarehouse(WarehouseBaseClass):
"""Warehouse implementation using local filesystem storage.
Expand All @@ -315,5 +347,9 @@ class FileSystemWarehouse(WarehouseBaseClass):
def __init__(self, root_dir: str = "warehouse"):
    # Each store lives in its own subdirectory beneath root_dir.
    stores = WarehouseStores(
        setup=FileStorage(f"{root_dir}/setup"),
        result=FileStorage(f"{root_dir}/result"),
        shared=FileStorage(f"{root_dir}/shared"),
        tasks=FileStorage(f"{root_dir}/tasks"),
    )
    super().__init__(stores)
2 changes: 2 additions & 0 deletions src/openfe/tests/orchestration/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
# This code is part of OpenFE and is licensed under the MIT license.
# For details, see https://github.com/OpenFreeEnergy/openfe
Loading
Loading