Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
54 changes: 17 additions & 37 deletions openfecli/commands/gather.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@

from openfecli import OFECommandPlugin
from openfecli.clicktypes import HyphenAwareChoice
from openfecli.commands.quickrun import _QuickrunResult

FAIL_STR = "Error" # string used to indicate a failed run in output tables.

Expand Down Expand Up @@ -181,30 +182,7 @@ def is_results_json(fpath: os.PathLike | str) -> bool:
return "estimate" in open(fpath, "r").read(20)


def load_json(fpath: os.PathLike | str) -> dict:
    """Load a JSON file containing a gufe object.

    Parameters
    ----------
    fpath : os.PathLike | str
        The path to a gufe-serialized JSON.

    Returns
    -------
    dict
        A dict containing data from the results JSON.

    """
    # TODO: move this function to openfe/utils
    import json

    from gufe.tokenization import JSON_HANDLER

    # Use a context manager so the file handle is closed deterministically,
    # even if decoding raises.
    with open(fpath, "r") as json_file:
        return json.load(json_file, cls=JSON_HANDLER.decoder)


def _get_names(result: dict) -> tuple[str, str]:
def _get_names(result: _QuickrunResult) -> tuple[str, str]:
"""Get the ligand names from a unit's results data.

Parameters
Expand All @@ -219,7 +197,7 @@ def _get_names(result: dict) -> tuple[str, str]:
"""

# TODO: I don't like this [0][0] indexing, but I can't think of a better way currently
protocol_data = list(result["protocol_result"]["data"].values())[0][0]
protocol_data = list(result.protocol_result["data"].values())[0][0]
try:
name_A = protocol_data["inputs"]["setup_results"]["inputs"]["ligandmapping"]["componentA"][
"molprops"
Expand All @@ -234,10 +212,10 @@ def _get_names(result: dict) -> tuple[str, str]:
return str(name_A), str(name_B)


def _get_type(result: dict) -> Literal["vacuum", "solvent", "complex"]:
def _get_type(result: _QuickrunResult) -> Literal["vacuum", "solvent", "complex"]:
"""Determine the simulation type based on the component types."""

protocol_data = list(result["protocol_result"]["data"].values())[0][0]
protocol_data = list(result.protocol_result["data"].values())[0][0]
try:
component_types = [
x["__module__"]
Expand Down Expand Up @@ -270,7 +248,7 @@ def _legacy_get_type(res_fn: os.PathLike | str) -> Literal["vacuum", "solvent",


def _get_result_id(
result: dict, result_fn: os.PathLike | str
result: _QuickrunResult, result_fn: os.PathLike | str
) -> tuple[tuple[str, str], Literal["vacuum", "solvent", "complex"]]:
"""Extract the name and simulation type from a results dict.

Expand All @@ -296,7 +274,9 @@ def _get_result_id(
return (ligA, ligB), simtype


def _load_valid_result_json(fpath: os.PathLike | str) -> tuple[tuple | None, dict | None]:
def _load_valid_result_json(
fpath: os.PathLike | str,
) -> tuple[tuple | None, _QuickrunResult | None]:
"""Load the data from a results JSON into a dict.

Parameters
Expand All @@ -311,25 +291,25 @@ def _load_valid_result_json(fpath: os.PathLike | str) -> tuple[tuple | None, dic
or None if the JSON file is invalid or missing.

"""

# TODO: only load this once during collection, then pass namedtuple(fname, dict) into this function
# for now though, it's not the bottleneck on performance
result = load_json(fpath)
result = _QuickrunResult.from_json(fpath)

try:
result_id = _get_result_id(result, fpath)
except (ValueError, IndexError):
click.secho(f"{fpath}: Missing ligand names and/or simulation type. Skipping.",err=True, fg="yellow") # fmt: skip
return None, None
if result["estimate"] is None:
if result.estimate is None:
click.secho(f"{fpath}: No 'estimate' found, assuming to be a failed simulation.",err=True, fg="yellow") # fmt: skip
return result_id, None
if result["uncertainty"] is None:
if result.uncertainty is None:
click.secho(f"{fpath}: No 'uncertainty' found, assuming to be a failed simulation.",err=True, fg="yellow") # fmt: skip
return result_id, None
if result["unit_results"] == {}:
if result.unit_results == {}:
click.secho(f"{fpath}: No 'unit_results' found, assuming to be a failed simulation.",err=True, fg="yellow") # fmt: skip
return result_id, None
if all("exception" in u for u in result["unit_results"].values()):
if all("exception" in u for u in result.unit_results.values()):
click.secho(f"{fpath}: Exception found in all 'unit_results', assuming to be a failed simulation.",err=True, fg="yellow") # fmt: skip
return result_id, None

Expand Down Expand Up @@ -682,7 +662,7 @@ def _get_legs_from_result_jsons(
v[0]["outputs"]["unit_estimate"],
v[0]["outputs"]["unit_estimate_error"],
)
for v in result["protocol_result"]["data"].values()
for v in result.protocol_result["data"].values()
]
legs[names][simtype].append(parsed_raw_data)
else:
Expand All @@ -692,7 +672,7 @@ def _get_legs_from_result_jsons(
else:
dGs = [
v[0]["outputs"]["unit_estimate"]
for v in result["protocol_result"]["data"].values()
for v in result.protocol_result["data"].values()
]
legs[names][simtype].extend(dGs)
return legs
Expand Down
34 changes: 18 additions & 16 deletions openfecli/commands/gather_abfe.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,12 @@
from openfecli.commands.gather import (
_collect_result_jsons,
format_df_with_precision,
load_json,
rich_print_to_stdout,
)
from openfecli.quickrun_result import _QuickrunResult


def _get_name(result: dict) -> str:
def _get_name(result: _QuickrunResult) -> str:
"""Get the ligand name from a unit's results data.

Parameters
Expand All @@ -32,13 +32,15 @@ def _get_name(result: dict) -> str:
Ligand name corresponding to the results.
"""

solvent_data = list(result["protocol_result"]["data"]["solvent"].values())[0][0]
solvent_data = list(result.protocol_result["data"]["solvent"].values())[0][0]
name = solvent_data["inputs"]["alchemical_components"]["stateA"][0]["molprops"]["ofe-name"]

return str(name)


def _load_valid_result_json(fpath: os.PathLike | str) -> tuple[tuple | None, dict | None]:
def _load_valid_result_json(
fpath: os.PathLike | str,
) -> tuple[tuple | None, _QuickrunResult | None]:
"""Load the data from a results JSON into a dict.

Parameters
Expand All @@ -63,19 +65,19 @@ def _load_valid_result_json(fpath: os.PathLike | str) -> tuple[tuple | None, dic

# TODO: only load this once during collection, then pass namedtuple(fname, dict) into this function
# for now though, it's not the bottleneck on performance
result = load_json(fpath)
result = _QuickrunResult.from_json(fpath)
try:
names = _get_name(result)
except (ValueError, IndexError):
click.secho(f"{fpath}: Missing ligand names and/or simulation type. Skipping.",err=True, fg="yellow") # fmt: skip
return None, None
if result["estimate"] is None:
if result.estimate is None:
click.secho(f"{fpath}: No 'estimate' found, assuming to be a failed simulation.",err=True, fg="yellow") # fmt: skip
return names, None
if result["uncertainty"] is None:
if result.uncertainty is None:
click.secho(f"{fpath}: No 'uncertainty' found, assuming to be a failed simulation.",err=True, fg="yellow") # fmt: skip
return names, None
if all("exception" in u for u in result["unit_results"].values()):
if all("exception" in u for u in result.unit_results.values()):
click.secho(f"{fpath}: Exception found in all 'unit_results', assuming to be a failed simulation.",err=True, fg="yellow") # fmt: skip
return names, None
return names, result
Expand Down Expand Up @@ -110,17 +112,17 @@ def _get_legs_from_result_jsons(
if name is None: # this means it couldn't find name and/or simtype
continue

dgs[name]["overall"].append([result["estimate"], result["uncertainty"]])
proto_key = [k for k in result["unit_results"].keys() if k.startswith("ProtocolUnitResult")]
dgs[name]["overall"].append([result.estimate, result.uncertainty])
proto_key = [k for k in result.unit_results.keys() if k.startswith("ProtocolUnitResult")]
for p in proto_key:
if "unit_estimate" in result["unit_results"][p]["outputs"]:
simtype = result["unit_results"][p]["outputs"]["simtype"]
dg = result["unit_results"][p]["outputs"]["unit_estimate"]
dg_error = result["unit_results"][p]["outputs"]["unit_estimate_error"]
if "unit_estimate" in result.unit_results[p]["outputs"]:
simtype = result.unit_results[p]["outputs"]["simtype"]
dg = result.unit_results[p]["outputs"]["unit_estimate"]
dg_error = result.unit_results[p]["outputs"]["unit_estimate_error"]

dgs[name][simtype].append([dg, dg_error])
if "standard_state_correction" in result["unit_results"][p]["outputs"]:
corr = result["unit_results"][p]["outputs"]["standard_state_correction"]
if "standard_state_correction" in result.unit_results[p]["outputs"]:
corr = result.unit_results[p]["outputs"]["standard_state_correction"]
dgs[name]["standard_state_correction"].append([corr, 0 * unit.kilocalorie_per_mole])
else:
continue
Expand Down
32 changes: 16 additions & 16 deletions openfecli/commands/gather_septop.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,9 @@
from openfecli.commands.gather import (
_collect_result_jsons,
format_df_with_precision,
load_json,
rich_print_to_stdout,
)
from openfecli.quickrun_result import _QuickrunResult


def _load_valid_result_json(fpath: os.PathLike | str) -> tuple[tuple | None, dict | None]:
Expand Down Expand Up @@ -44,19 +44,19 @@ def _load_valid_result_json(fpath: os.PathLike | str) -> tuple[tuple | None, dic

# TODO: only load this once during collection, then pass namedtuple(fname, dict) into this function
# for now though, it's not the bottleneck on performance
result = load_json(fpath)
result = _QuickrunResult.from_json(fpath)
try:
names = _get_names(result)
except (ValueError, IndexError):
click.secho(f"{fpath}: Missing ligand names and/or simulation type. Skipping.",err=True, fg="yellow") # fmt: skip
return None, None
if result["estimate"] is None:
if result.estimate is None:
click.secho(f"{fpath}: No 'estimate' found, assuming to be a failed simulation.",err=True, fg="yellow") # fmt: skip
return names, None
if result["uncertainty"] is None:
if result.uncertainty is None:
click.secho(f"{fpath}: No 'uncertainty' found, assuming to be a failed simulation.",err=True, fg="yellow") # fmt: skip
return names, None
if all("exception" in u for u in result["unit_results"].values()):
if all("exception" in u for u in result.unit_results.values()):
click.secho(f"{fpath}: Exception found in all 'unit_results', assuming to be a failed simulation.",err=True, fg="yellow") # fmt: skip
return names, None
return names, result
Expand Down Expand Up @@ -91,22 +91,22 @@ def _get_legs_from_result_jsons(
if names is None: # this means it couldn't find names and/or simtype
continue

ddgs[names]["overall"].append([result["estimate"], result["uncertainty"]])
ddgs[names]["overall"].append([result.estimate, result.uncertainty])
proto_key = [
k
for k in result["unit_results"].keys()
for k in result.unit_results.keys()
if k.startswith("ProtocolUnitResult")
] # fmt: skip
for p in proto_key:
if "unit_estimate" in result["unit_results"][p]["outputs"]:
simtype = result["unit_results"][p]["outputs"]["simtype"]
dg = result["unit_results"][p]["outputs"]["unit_estimate"]
dg_error = result["unit_results"][p]["outputs"]["unit_estimate_error"]
if "unit_estimate" in result.unit_results[p]["outputs"]:
simtype = result.unit_results[p]["outputs"]["simtype"]
dg = result.unit_results[p]["outputs"]["unit_estimate"]
dg_error = result.unit_results[p]["outputs"]["unit_estimate_error"]

ddgs[names][simtype].append([dg, dg_error])
elif "standard_state_correction_A" in result["unit_results"][p]["outputs"]:
corr_A = result["unit_results"][p]["outputs"]["standard_state_correction_A"]
corr_B = result["unit_results"][p]["outputs"]["standard_state_correction_B"]
elif "standard_state_correction_A" in result.unit_results[p]["outputs"]:
corr_A = result.unit_results[p]["outputs"]["standard_state_correction_A"]
corr_B = result.unit_results[p]["outputs"]["standard_state_correction_B"]
ddgs[names]["standard_state_correction_A"].append(
[corr_A, 0 * unit.kilocalorie_per_mole]
)
Expand All @@ -119,7 +119,7 @@ def _get_legs_from_result_jsons(
return ddgs


def _get_names(result: dict) -> tuple[str, str]:
def _get_names(result: _QuickrunResult) -> tuple[str, str]:
"""Get the ligand names from a unit's results data.

Parameters
Expand All @@ -133,7 +133,7 @@ def _get_names(result: dict) -> tuple[str, str]:
Ligand names corresponding to the results.
"""

solvent_data = list(result["protocol_result"]["data"]["solvent"].values())[0][0]
solvent_data = list(result.protocol_result["data"]["solvent"].values())[0][0]

name_A = solvent_data["inputs"]["alchemical_components"]["stateA"][0]["molprops"]["ofe-name"]
name_B = solvent_data["inputs"]["alchemical_components"]["stateB"][0]["molprops"]["ofe-name"]
Expand Down
21 changes: 9 additions & 12 deletions openfecli/commands/quickrun.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
# This code is part of OpenFE and is licensed under the MIT license.
# For details, see https://github.com/OpenFreeEnergy/openfe

import json
import pathlib

import click

from openfecli import OFECommandPlugin
from openfecli.quickrun_result import _QuickrunResult
from openfecli.utils import configure_logger, print_duration, write


Expand Down Expand Up @@ -114,17 +114,14 @@ def quickrun(transformation, work_dir, output):
else:
estimate = uncertainty = None # for output file

out_dict = {
"estimate": estimate,
"uncertainty": uncertainty,
"protocol_result": prot_result.to_dict(),
"unit_results": {
unit.key: unit.to_keyed_dict() for unit in dagresult.protocol_unit_results
},
}

with open(output, mode="w") as outf:
json.dump(out_dict, outf, cls=JSON_HANDLER.encoder)
quickrun_result = _QuickrunResult(
estimate=estimate,
uncertainty=uncertainty,
protocol_result=prot_result.to_dict(),
unit_results={unit.key: unit.to_keyed_dict() for unit in dagresult.protocol_unit_results},
)

quickrun_result.to_json(output)

write(f"Here is the result:\n\tdG = {estimate} ± {uncertainty}\n")
write("")
Expand Down
51 changes: 51 additions & 0 deletions openfecli/quickrun_result.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
import json
from dataclasses import asdict, dataclass
from os import PathLike
from typing import Any, Self

from gufe.tokenization import JSON_HANDLER
from openff.units import Quantity


@dataclass
class _QuickrunResult:
    """Container for the data stored in a results JSON written by ``openfe quickrun``.

    Could ProtocolResults store this data alongside ``n_protocol_dag_results``?
    """

    # ``estimate`` and ``uncertainty`` are ``None`` when the simulation failed
    # (quickrun writes ``None`` for both in that case).
    estimate: Quantity | None
    uncertainty: Quantity | None
    # Output of ``ProtocolResult.to_dict()``.
    protocol_result: dict[str, Any]
    # Maps each protocol unit result's key to its ``to_keyed_dict()`` output.
    unit_results: dict[str, dict]

    def to_json(self, filepath: PathLike | str) -> None:
        """Serialize this result to a JSON file.

        Parameters
        ----------
        filepath : os.PathLike | str
            Destination path; an existing file is overwritten.
        """
        with open(filepath, mode="w") as file:
            json.dump(asdict(self), file, cls=JSON_HANDLER.encoder)

    @classmethod
    def from_json(cls, file: PathLike | str | None = None, content: str | None = None) -> Self:
        """Create a ``_QuickrunResult`` from a results JSON file or JSON string.

        Exactly one of ``file`` and ``content`` must be provided.

        Parameters
        ----------
        file : os.PathLike | str | None
            The path to a results JSON generated by ``openfe quickrun``.
        content : str | None
            The JSON document itself, as a string.

        Returns
        -------
        _QuickrunResult
            A _QuickrunResult instance containing the data from ``file`` or
            ``content``.

        Raises
        ------
        ValueError
            If both or neither of ``file`` and ``content`` are given.
        """
        # similar to gufe.tokenization.from_json
        if content is not None and file is not None:
            raise ValueError("Cannot specify both `content` and `file`; only one input allowed")
        if content is None and file is None:
            raise ValueError("Must specify either `content` or `file` for JSON input")

        # Compare against None rather than relying on truthiness: an empty
        # string must reach the loader (and fail there with a clear error)
        # instead of silently leaving ``data`` undefined.
        if file is not None:
            with open(file, "r") as json_file:
                data = json.load(json_file, cls=JSON_HANDLER.decoder)
        else:
            data = json.loads(content, cls=JSON_HANDLER.decoder)
        return cls(**data)
Loading
Loading