diff --git a/environment.yml b/environment.yml index c45d8102a..f74301b03 100644 --- a/environment.yml +++ b/environment.yml @@ -53,7 +53,7 @@ dependencies: # Control blas/openmp threads - threadpoolctl - pip: - - git+https://github.com/OpenFreeEnergy/gufe@main + - git+https://github.com/OpenFreeEnergy/gufe@restart_execute - run_constrained: # drop this pin when handled upstream in espaloma-feedstock - smirnoff99frosst>=1.1.0.1 #https://github.com/openforcefield/smirnoff99Frosst/issues/109 diff --git a/src/openfe/tests/protocols/openmm_abfe/test_abfe_protocol_results.py b/src/openfe/tests/protocols/openmm_abfe/test_abfe_protocol_results.py index 5d815c713..53574f972 100644 --- a/src/openfe/tests/protocols/openmm_abfe/test_abfe_protocol_results.py +++ b/src/openfe/tests/protocols/openmm_abfe/test_abfe_protocol_results.py @@ -79,12 +79,12 @@ def patcher(): yield -def test_gather(benzene_complex_dag, patcher, tmpdir): +def test_gather(benzene_complex_dag, patcher, tmp_path): # check that .gather behaves as expected dagres = gufe.protocols.execute_DAG( benzene_complex_dag, - shared_basedir=tmpdir, - scratch_basedir=tmpdir, + shared_basedir=tmp_path, + scratch_basedir=tmp_path, keep_shared=True, ) diff --git a/src/openfe/tests/protocols/openmm_ahfe/test_ahfe_protocol_results.py b/src/openfe/tests/protocols/openmm_ahfe/test_ahfe_protocol_results.py index 0cb2d2d25..619b199d1 100644 --- a/src/openfe/tests/protocols/openmm_ahfe/test_ahfe_protocol_results.py +++ b/src/openfe/tests/protocols/openmm_ahfe/test_ahfe_protocol_results.py @@ -99,12 +99,12 @@ def patcher(): yield -def test_gather(benzene_solvation_dag, patcher, tmpdir): +def test_gather(benzene_solvation_dag, patcher, tmp_path): # check that .gather behaves as expected dagres = gufe.protocols.execute_DAG( benzene_solvation_dag, - shared_basedir=tmpdir, - scratch_basedir=tmpdir, + shared_basedir=tmp_path, + scratch_basedir=tmp_path, keep_shared=True, ) diff --git a/src/openfe/tests/protocols/openmm_md/test_plain_md_protocol.py b/src/openfe/tests/protocols/openmm_md/test_plain_md_protocol.py index 60c7e8c47..b8e3153ff 100644 --- a/src/openfe/tests/protocols/openmm_md/test_plain_md_protocol.py +++ b/src/openfe/tests/protocols/openmm_md/test_plain_md_protocol.py @@ -508,7 +508,7 @@ def test_unit_tagging(solvent_protocol_dag, tmpdir): assert len(repeats) == 3 -def test_gather(solvent_protocol_dag, tmpdir): +def test_gather(solvent_protocol_dag, tmp_path): # check .gather behaves as expected with mock.patch( "openfe.protocols.openmm_md.plain_md_methods.PlainMDProtocolUnit.run", @@ -519,8 +519,8 @@ def test_gather(solvent_protocol_dag, tmpdir): ): dagres = gufe.protocols.execute_DAG( solvent_protocol_dag, - shared_basedir=tmpdir, - scratch_basedir=tmpdir, + shared_basedir=tmp_path, + scratch_basedir=tmp_path, keep_shared=True, ) diff --git a/src/openfe/tests/protocols/openmm_rfe/test_hybrid_top_protocol.py b/src/openfe/tests/protocols/openmm_rfe/test_hybrid_top_protocol.py index bd7a1f72f..c632505ea 100644 --- a/src/openfe/tests/protocols/openmm_rfe/test_hybrid_top_protocol.py +++ b/src/openfe/tests/protocols/openmm_rfe/test_hybrid_top_protocol.py @@ -1225,7 +1225,7 @@ def test_unit_tagging(solvent_protocol_dag, tmpdir): assert len(setup_results) == len(sim_results) == len(analysis_results) == 3 -def test_gather(solvent_protocol_dag, tmpdir): +def test_gather(solvent_protocol_dag, tmp_path): # check .gather behaves as expected with ( mock.patch( @@ -1263,8 +1263,8 @@ def test_gather(solvent_protocol_dag, tmpdir): ): dagres = gufe.protocols.execute_DAG( solvent_protocol_dag, - shared_basedir=tmpdir, - scratch_basedir=tmpdir, + shared_basedir=tmp_path, + scratch_basedir=tmp_path, keep_shared=True, ) diff --git a/src/openfe/tests/protocols/openmm_septop/test_septop_protocol.py b/src/openfe/tests/protocols/openmm_septop/test_septop_protocol.py index e5e4a9f91..4a22aebcf 100644 --- a/src/openfe/tests/protocols/openmm_septop/test_septop_protocol.py +++ b/src/openfe/tests/protocols/openmm_septop/test_septop_protocol.py @@ -1295,7 +1295,7 @@ def test_unit_tagging(benzene_toluene_dag, tmpdir): assert len(complex_repeats) == len(solv_repeats) == 2 -def test_gather(benzene_toluene_dag, tmpdir): +def test_gather(benzene_toluene_dag, tmp_path): # check that .gather behaves as expected with ( mock.patch( @@ -1339,8 +1339,8 @@ def test_gather(benzene_toluene_dag, tmpdir): ): dagres = gufe.protocols.execute_DAG( benzene_toluene_dag, - shared_basedir=tmpdir, - scratch_basedir=tmpdir, + shared_basedir=tmp_path, + scratch_basedir=tmp_path, keep_shared=True, ) diff --git a/src/openfecli/commands/quickrun.py b/src/openfecli/commands/quickrun.py index f34410d69..cf809605c 100644 --- a/src/openfecli/commands/quickrun.py +++ b/src/openfecli/commands/quickrun.py @@ -51,7 +51,9 @@ def quickrun(transformation, work_dir, output): import logging import os import sys + from json import JSONDecodeError + from gufe import ProtocolDAG from gufe.protocols.protocoldag import execute_DAG from gufe.tokenization import JSON_HANDLER from gufe.transformations.transformation import Transformation @@ -94,13 +96,26 @@ def quickrun(transformation, work_dir, output): else: output.parent.mkdir(exist_ok=True, parents=True) - write("Planning simulations for this edge...") - dag = trans.create() + # Attempt to either deserialize or freshly create DAG + if (work_dir / "protocol_dag.json").is_file(): + write("Attempting to recover edge simulations from file") + try: + dag = ProtocolDAG.from_json(work_dir / "protocol_dag.json") + except JSONDecodeError: + errmsg = "Recovery failed, please clean workdir before continuing" + raise click.ClickException(errmsg) + else: + # Create the DAG instead and then serialize for later resuming + write("Planning simulations for this edge...") + dag = trans.create() + dag.to_json(work_dir / "protocol_dag.json") + write("Starting the simulations for this edge...") dagresult = execute_DAG( dag, shared_basedir=work_dir, scratch_basedir=work_dir, + unitresults_basedir=work_dir, keep_shared=True, raise_error=False, n_retries=2,