diff --git a/openfecli/commands/gather.py b/openfecli/commands/gather.py index cbb2b93f6..655907a98 100644 --- a/openfecli/commands/gather.py +++ b/openfecli/commands/gather.py @@ -11,6 +11,7 @@ from openfecli import OFECommandPlugin from openfecli.clicktypes import HyphenAwareChoice +from openfecli.commands.quickrun import _QuickrunResult FAIL_STR = "Error" # string used to indicate a failed run in output tables. @@ -181,30 +182,7 @@ def is_results_json(fpath: os.PathLike | str) -> bool: return "estimate" in open(fpath, "r").read(20) -def load_json(fpath: os.PathLike | str) -> dict: - """Load a JSON file containing a gufe object. - - Parameters - ---------- - fpath : os.PathLike | str - The path to a gufe-serialized JSON. - - - Returns - ------- - dict - A dict containing data from the results JSON. - - """ - # TODO: move this function to openfe/utils - import json - - from gufe.tokenization import JSON_HANDLER - - return json.load(open(fpath, "r"), cls=JSON_HANDLER.decoder) - - -def _get_names(result: dict) -> tuple[str, str]: +def _get_names(result: _QuickrunResult) -> tuple[str, str]: """Get the ligand names from a unit's results data. Parameters @@ -219,7 +197,7 @@ def _get_names(result: dict) -> tuple[str, str]: """ # TODO: I don't like this [0][0] indexing, but I can't think of a better way currently - protocol_data = list(result["protocol_result"]["data"].values())[0][0] + protocol_data = list(result.protocol_result["data"].values())[0][0] try: name_A = protocol_data["inputs"]["setup_results"]["inputs"]["ligandmapping"]["componentA"][ "molprops" @@ -234,10 +212,10 @@ def _get_names(result: dict) -> tuple[str, str]: return str(name_A), str(name_B) -def _get_type(result: dict) -> Literal["vacuum", "solvent", "complex"]: +def _get_type(result: _QuickrunResult) -> Literal["vacuum", "solvent", "complex"]: """Determine the simulation type based on the component types.""" - protocol_data = list(result["protocol_result"]["data"].values())[0][0] + protocol_data = list(result.protocol_result["data"].values())[0][0] try: component_types = [ x["__module__"] @@ -270,7 +248,7 @@ def _legacy_get_type(res_fn: os.PathLike | str) -> Literal["vacuum", "solvent", def _get_result_id( - result: dict, result_fn: os.PathLike | str + result: _QuickrunResult, result_fn: os.PathLike | str ) -> tuple[tuple[str, str], Literal["vacuum", "solvent", "complex"]]: """Extract the name and simulation type from a results dict. @@ -296,7 +274,9 @@ def _get_result_id( return (ligA, ligB), simtype -def _load_valid_result_json(fpath: os.PathLike | str) -> tuple[tuple | None, dict | None]: +def _load_valid_result_json( + fpath: os.PathLike | str, +) -> tuple[tuple | None, _QuickrunResult | None]: """Load the data from a results JSON into a dict. Parameters @@ -311,25 +291,25 @@ def _load_valid_result_json(fpath: os.PathLike | str) -> tuple[tuple | None, dic or None if the JSON file is invalid or missing. """ - # TODO: only load this once during collection, then pass namedtuple(fname, dict) into this function # for now though, it's not the bottleneck on performance - result = load_json(fpath) + result = _QuickrunResult.from_json(fpath) + try: result_id = _get_result_id(result, fpath) except (ValueError, IndexError): click.secho(f"{fpath}: Missing ligand names and/or simulation type. Skipping.",err=True, fg="yellow") # fmt: skip return None, None - if result["estimate"] is None: + if result.estimate is None: click.secho(f"{fpath}: No 'estimate' found, assuming to be a failed simulation.",err=True, fg="yellow") # fmt: skip return result_id, None - if result["uncertainty"] is None: + if result.uncertainty is None: click.secho(f"{fpath}: No 'uncertainty' found, assuming to be a failed simulation.",err=True, fg="yellow") # fmt: skip return result_id, None - if result["unit_results"] == {}: + if result.unit_results == {}: click.secho(f"{fpath}: No 'unit_results' found, assuming to be a failed simulation.",err=True, fg="yellow") # fmt: skip return result_id, None - if all("exception" in u for u in result["unit_results"].values()): + if all("exception" in u for u in result.unit_results.values()): click.secho(f"{fpath}: Exception found in all 'unit_results', assuming to be a failed simulation.",err=True, fg="yellow") # fmt: skip return result_id, None @@ -682,7 +662,7 @@ def _get_legs_from_result_jsons( v[0]["outputs"]["unit_estimate"], v[0]["outputs"]["unit_estimate_error"], ) - for v in result["protocol_result"]["data"].values() + for v in result.protocol_result["data"].values() ] legs[names][simtype].append(parsed_raw_data) else: @@ -692,7 +672,7 @@ def _get_legs_from_result_jsons( else: dGs = [ v[0]["outputs"]["unit_estimate"] - for v in result["protocol_result"]["data"].values() + for v in result.protocol_result["data"].values() ] legs[names][simtype].extend(dGs) return legs diff --git a/openfecli/commands/gather_abfe.py b/openfecli/commands/gather_abfe.py index a541bfc52..6b836585f 100644 --- a/openfecli/commands/gather_abfe.py +++ b/openfecli/commands/gather_abfe.py @@ -13,12 +13,12 @@ from openfecli.commands.gather import ( _collect_result_jsons, format_df_with_precision, - load_json, rich_print_to_stdout, ) +from openfecli.quickrun_result import _QuickrunResult -def _get_name(result: dict) -> str: +def _get_name(result: _QuickrunResult) -> str: """Get the ligand name from a unit's results data. Parameters @@ -32,13 +32,15 @@ def _get_name(result: dict) -> str: Ligand name corresponding to the results. """ - solvent_data = list(result["protocol_result"]["data"]["solvent"].values())[0][0] + solvent_data = list(result.protocol_result["data"]["solvent"].values())[0][0] name = solvent_data["inputs"]["alchemical_components"]["stateA"][0]["molprops"]["ofe-name"] return str(name) -def _load_valid_result_json(fpath: os.PathLike | str) -> tuple[tuple | None, dict | None]: +def _load_valid_result_json( + fpath: os.PathLike | str, +) -> tuple[tuple | None, _QuickrunResult | None]: """Load the data from a results JSON into a dict. Parameters @@ -63,19 +65,19 @@ def _load_valid_result_json(fpath: os.PathLike | str) -> tuple[tuple | None, dic # TODO: only load this once during collection, then pass namedtuple(fname, dict) into this function # for now though, it's not the bottleneck on performance - result = load_json(fpath) + result = _QuickrunResult.from_json(fpath) try: names = _get_name(result) except (ValueError, IndexError): click.secho(f"{fpath}: Missing ligand names and/or simulation type. Skipping.",err=True, fg="yellow") # fmt: skip return None, None - if result["estimate"] is None: + if result.estimate is None: click.secho(f"{fpath}: No 'estimate' found, assuming to be a failed simulation.",err=True, fg="yellow") # fmt: skip return names, None - if result["uncertainty"] is None: + if result.uncertainty is None: click.secho(f"{fpath}: No 'uncertainty' found, assuming to be a failed simulation.",err=True, fg="yellow") # fmt: skip return names, None - if all("exception" in u for u in result["unit_results"].values()): + if all("exception" in u for u in result.unit_results.values()): click.secho(f"{fpath}: Exception found in all 'unit_results', assuming to be a failed simulation.",err=True, fg="yellow") # fmt: skip return names, None return names, result @@ -110,17 +112,17 @@ def _get_legs_from_result_jsons( if name is None: # this means it couldn't find name and/or simtype continue - dgs[name]["overall"].append([result["estimate"], result["uncertainty"]]) - proto_key = [k for k in result["unit_results"].keys() if k.startswith("ProtocolUnitResult")] + dgs[name]["overall"].append([result.estimate, result.uncertainty]) + proto_key = [k for k in result.unit_results.keys() if k.startswith("ProtocolUnitResult")] for p in proto_key: - if "unit_estimate" in result["unit_results"][p]["outputs"]: - simtype = result["unit_results"][p]["outputs"]["simtype"] - dg = result["unit_results"][p]["outputs"]["unit_estimate"] - dg_error = result["unit_results"][p]["outputs"]["unit_estimate_error"] + if "unit_estimate" in result.unit_results[p]["outputs"]: + simtype = result.unit_results[p]["outputs"]["simtype"] + dg = result.unit_results[p]["outputs"]["unit_estimate"] + dg_error = result.unit_results[p]["outputs"]["unit_estimate_error"] dgs[name][simtype].append([dg, dg_error]) - if "standard_state_correction" in result["unit_results"][p]["outputs"]: - corr = result["unit_results"][p]["outputs"]["standard_state_correction"] + if "standard_state_correction" in result.unit_results[p]["outputs"]: + corr = result.unit_results[p]["outputs"]["standard_state_correction"] dgs[name]["standard_state_correction"].append([corr, 0 * unit.kilocalorie_per_mole]) else: continue diff --git a/openfecli/commands/gather_septop.py b/openfecli/commands/gather_septop.py index 9a027a757..3fae2a171 100644 --- a/openfecli/commands/gather_septop.py +++ b/openfecli/commands/gather_septop.py @@ -14,9 +14,9 @@ from openfecli.commands.gather import ( _collect_result_jsons, format_df_with_precision, - load_json, rich_print_to_stdout, ) +from openfecli.quickrun_result import _QuickrunResult def _load_valid_result_json(fpath: os.PathLike | str) -> tuple[tuple | None, dict | None]: @@ -44,19 +44,19 @@ def _load_valid_result_json(fpath: os.PathLike | str) -> tuple[tuple | None, dic # TODO: only load this once during collection, then pass namedtuple(fname, dict) into this function # for now though, it's not the bottleneck on performance - result = load_json(fpath) + result = _QuickrunResult.from_json(fpath) try: names = _get_names(result) except (ValueError, IndexError): click.secho(f"{fpath}: Missing ligand names and/or simulation type. Skipping.",err=True, fg="yellow") # fmt: skip return None, None - if result["estimate"] is None: + if result.estimate is None: click.secho(f"{fpath}: No 'estimate' found, assuming to be a failed simulation.",err=True, fg="yellow") # fmt: skip return names, None - if result["uncertainty"] is None: + if result.uncertainty is None: click.secho(f"{fpath}: No 'uncertainty' found, assuming to be a failed simulation.",err=True, fg="yellow") # fmt: skip return names, None - if all("exception" in u for u in result["unit_results"].values()): + if all("exception" in u for u in result.unit_results.values()): click.secho(f"{fpath}: Exception found in all 'unit_results', assuming to be a failed simulation.",err=True, fg="yellow") # fmt: skip return names, None return names, result @@ -91,22 +91,22 @@ def _get_legs_from_result_jsons( if names is None: # this means it couldn't find names and/or simtype continue - ddgs[names]["overall"].append([result["estimate"], result["uncertainty"]]) + ddgs[names]["overall"].append([result.estimate, result.uncertainty]) proto_key = [ k - for k in result["unit_results"].keys() + for k in result.unit_results.keys() if k.startswith("ProtocolUnitResult") ] # fmt: skip for p in proto_key: - if "unit_estimate" in result["unit_results"][p]["outputs"]: - simtype = result["unit_results"][p]["outputs"]["simtype"] - dg = result["unit_results"][p]["outputs"]["unit_estimate"] - dg_error = result["unit_results"][p]["outputs"]["unit_estimate_error"] + if "unit_estimate" in result.unit_results[p]["outputs"]: + simtype = result.unit_results[p]["outputs"]["simtype"] + dg = result.unit_results[p]["outputs"]["unit_estimate"] + dg_error = result.unit_results[p]["outputs"]["unit_estimate_error"] ddgs[names][simtype].append([dg, dg_error]) - elif "standard_state_correction_A" in result["unit_results"][p]["outputs"]: - corr_A = result["unit_results"][p]["outputs"]["standard_state_correction_A"] - corr_B = result["unit_results"][p]["outputs"]["standard_state_correction_B"] + elif "standard_state_correction_A" in result.unit_results[p]["outputs"]: + corr_A = result.unit_results[p]["outputs"]["standard_state_correction_A"] + corr_B = result.unit_results[p]["outputs"]["standard_state_correction_B"] ddgs[names]["standard_state_correction_A"].append( [corr_A, 0 * unit.kilocalorie_per_mole] ) @@ -119,7 +119,7 @@ def _get_legs_from_result_jsons( return ddgs -def _get_names(result: dict) -> tuple[str, str]: +def _get_names(result: _QuickrunResult) -> tuple[str, str]: """Get the ligand names from a unit's results data. Parameters @@ -133,7 +133,7 @@ def _get_names(result: dict) -> tuple[str, str]: Ligand names corresponding to the results. """ - solvent_data = list(result["protocol_result"]["data"]["solvent"].values())[0][0] + solvent_data = list(result.protocol_result["data"]["solvent"].values())[0][0] name_A = solvent_data["inputs"]["alchemical_components"]["stateA"][0]["molprops"]["ofe-name"] name_B = solvent_data["inputs"]["alchemical_components"]["stateB"][0]["molprops"]["ofe-name"] diff --git a/openfecli/commands/quickrun.py b/openfecli/commands/quickrun.py index f34410d69..f52df614e 100644 --- a/openfecli/commands/quickrun.py +++ b/openfecli/commands/quickrun.py @@ -1,12 +1,12 @@ # This code is part of OpenFE and is licensed under the MIT license. # For details, see https://github.com/OpenFreeEnergy/openfe -import json import pathlib import click from openfecli import OFECommandPlugin +from openfecli.quickrun_result import _QuickrunResult from openfecli.utils import configure_logger, print_duration, write @@ -114,17 +114,14 @@ def quickrun(transformation, work_dir, output): else: estimate = uncertainty = None # for output file - out_dict = { - "estimate": estimate, - "uncertainty": uncertainty, - "protocol_result": prot_result.to_dict(), - "unit_results": { - unit.key: unit.to_keyed_dict() for unit in dagresult.protocol_unit_results - }, - } - - with open(output, mode="w") as outf: - json.dump(out_dict, outf, cls=JSON_HANDLER.encoder) + quickrun_result = _QuickrunResult( + estimate=estimate, + uncertainty=uncertainty, + protocol_result=prot_result.to_dict(), + unit_results={unit.key: unit.to_keyed_dict() for unit in dagresult.protocol_unit_results}, + ) + + quickrun_result.to_json(output) write(f"Here is the result:\n\tdG = {estimate} ± {uncertainty}\n") write("") diff --git a/openfecli/quickrun_result.py b/openfecli/quickrun_result.py new file mode 100644 index 000000000..e3c79118b --- /dev/null +++ b/openfecli/quickrun_result.py @@ -0,0 +1,51 @@ +import json +from dataclasses import asdict, dataclass +from os import PathLike +from typing import Any, Self + +from gufe.tokenization import JSON_HANDLER +from openff.units import Quantity + + +@dataclass +class _QuickrunResult: + """ + Class for storing protocol result data along with useful metadata. + Could ProtocolResults store this data alongside ``n_protocol_dag_results``? + """ + + estimate: Quantity + uncertainty: Quantity + protocol_result: dict[str, Any] + unit_results: dict[int, dict] + + def to_json(self, filepath) -> None: + with open(filepath, mode="w") as file: + json.dump(asdict(self), file, cls=JSON_HANDLER.encoder) + + @classmethod + def from_json(cls, file: PathLike | None, content: str | None = None) -> Self: + """Load a JSON file containing a gufe object. + + Parameters + ---------- + fpath : os.PathLike | str + The path to a results JSON generated by ``openfe quickrun``. + + + Returns + ------- + _QuickrunResult + A _QuickrunResult instance containing the data from ``file, co``. + + """ + # similar to gufe.tokenization.from_json + if content is not None and file is not None: + raise ValueError("Cannot specify both `content` and `file`; only one input allowed") + elif content is None and file is None: + raise ValueError("Must specify either `content` and `file` for JSON input") + if file: + data = json.load(open(file, "r"), cls=JSON_HANDLER.decoder) + if content: + data = json.loads(content, cls=JSON_HANDLER.decoder) + return cls(**data) diff --git a/openfecli/tests/commands/test_gather.py b/openfecli/tests/commands/test_gather.py index 1cae95aa8..60d8ccc36 100644 --- a/openfecli/tests/commands/test_gather.py +++ b/openfecli/tests/commands/test_gather.py @@ -17,6 +17,7 @@ ) from openfecli.commands.gather_abfe import gather_abfe from openfecli.commands.gather_septop import gather_septop +from openfecli.commands.quickrun import _QuickrunResult from ..conftest import HAS_INTERNET from ..utils import assert_click_success @@ -68,62 +69,88 @@ def test_get_column(val, col): assert _get_column(val) == col +@pytest.fixture +def min_valid_quickrun_result(min_result_json): + return _QuickrunResult(**min_result_json) + + class TestResultLoading: - def test_minimal_valid_results(self, capsys, min_result_json): - with mock.patch("openfecli.commands.gather.load_json", return_value=min_result_json): + def test_minimal_valid_results(self, capsys, min_valid_quickrun_result): + with mock.patch( + "openfecli.commands.gather._QuickrunResult.from_json", + return_value=min_valid_quickrun_result, + ): result = _load_valid_result_json(fpath="") captured = capsys.readouterr() - assert result == ((("lig_ejm_31", "lig_ejm_42"), "solvent"), min_result_json) + assert result == ((("lig_ejm_31", "lig_ejm_42"), "solvent"), min_valid_quickrun_result) assert captured.err == "" - def test_skip_missing_unit_result(self, capsys, min_result_json): - min_result_json["unit_results"] = {} + def test_skip_missing_unit_result(self, capsys, min_valid_quickrun_result): + min_valid_quickrun_result.unit_results = {} - with mock.patch("openfecli.commands.gather.load_json", return_value=min_result_json): + with mock.patch( + "openfecli.commands.gather._QuickrunResult.from_json", + return_value=min_valid_quickrun_result, + ): result = _load_valid_result_json(fpath="") captured = capsys.readouterr() assert result == ((("lig_ejm_31", "lig_ejm_42"), "solvent"), None) assert "No 'unit_results' found" in captured.err - def test_skip_missing_estimate(self, capsys, min_result_json): - min_result_json["estimate"] = None + def test_skip_missing_estimate(self, capsys, min_valid_quickrun_result): + min_valid_quickrun_result.estimate = None - with mock.patch("openfecli.commands.gather.load_json", return_value=min_result_json): + with mock.patch( + "openfecli.commands.gather._QuickrunResult.from_json", + return_value=min_valid_quickrun_result, + ): result = _load_valid_result_json(fpath="") captured = capsys.readouterr() assert result == ((("lig_ejm_31", "lig_ejm_42"), "solvent"), None) assert "No 'estimate' found" in captured.err - def test_skip_missing_uncertainty(self, capsys, min_result_json): - min_result_json["uncertainty"] = None + def test_skip_missing_uncertainty(self, capsys, min_valid_quickrun_result): + min_valid_quickrun_result.uncertainty = None - with mock.patch("openfecli.commands.gather.load_json", return_value=min_result_json): + with mock.patch( + "openfecli.commands.gather._QuickrunResult.from_json", + return_value=min_valid_quickrun_result, + ): result = _load_valid_result_json(fpath="") captured = capsys.readouterr() assert result == ((("lig_ejm_31", "lig_ejm_42"), "solvent"), None) assert "No 'uncertainty' found" in captured.err - def test_skip_all_failed_runs(self, capsys, min_result_json): - del min_result_json["unit_results"]["ProtocolUnitResult-e85"] - with mock.patch("openfecli.commands.gather.load_json", return_value=min_result_json): + def test_skip_all_failed_runs(self, capsys, min_valid_quickrun_result): + del min_valid_quickrun_result.unit_results["ProtocolUnitResult-e85"] + with mock.patch( + "openfecli.commands.gather._QuickrunResult.from_json", + return_value=min_valid_quickrun_result, + ): result = _load_valid_result_json(fpath="") captured = capsys.readouterr() assert result == ((("lig_ejm_31", "lig_ejm_42"), "solvent"), None) assert "Exception found in all" in captured.err - def test_missing_pr_data(self, capsys, min_result_json): - min_result_json["protocol_result"]["data"] = {} - with mock.patch("openfecli.commands.gather.load_json", return_value=min_result_json): + def test_missing_pr_data(self, capsys, min_valid_quickrun_result): + min_valid_quickrun_result.protocol_result["data"] = {} + with mock.patch( + "openfecli.commands.gather._QuickrunResult.from_json", + return_value=min_valid_quickrun_result, + ): result = _load_valid_result_json(fpath="") captured = capsys.readouterr() assert result == (None, None) assert "Missing ligand names and/or simulation type. Skipping" in captured.err - def test_get_legs_from_result_jsons(self, capsys, min_result_json): + def test_get_legs_from_result_jsons(self, capsys, min_valid_quickrun_result): """Test that exceptions are handled correctly at the _get_legs_from_results_json level.""" - min_result_json["protocol_result"]["data"] = {} + min_valid_quickrun_result.protocol_result["data"] = {} - with mock.patch("openfecli.commands.gather.load_json", return_value=min_result_json): + with mock.patch( + "openfecli.commands.gather._QuickrunResult.from_json", + return_value=min_valid_quickrun_result, + ): result = _get_legs_from_result_jsons(result_fns=[""], report="dg") captured = capsys.readouterr() assert result == {}