diff --git a/src/openfe/tests/storage/test_warehouse.py b/src/openfe/tests/storage/test_warehouse.py index 850ae172a..572769d70 100644 --- a/src/openfe/tests/storage/test_warehouse.py +++ b/src/openfe/tests/storage/test_warehouse.py @@ -1,6 +1,7 @@ import os import tempfile from pathlib import Path +from typing import Literal from unittest import mock import pytest @@ -19,28 +20,42 @@ def test_store_protocol_dag_result(self): pytest.skip("Not implemented yet") @staticmethod - def _test_store_load_same_process(obj, store_func_name, load_func_name): - store = MemoryStorage() - stores = WarehouseStores(setup=store) + def _test_store_load_same_process( + obj, store_func_name, load_func_name, store_name: Literal["setup", "result"] + ): + setup_store = MemoryStorage() + result_store = MemoryStorage() + stores = WarehouseStores(setup=setup_store, result=result_store) client = WarehouseBaseClass(stores) store_func = getattr(client, store_func_name) load_func = getattr(client, load_func_name) - assert store._data == {} + assert setup_store._data == {} + assert result_store._data == {} store_func(obj) - assert store._data != {} - reloaded = load_func(obj.key) + store_under_test: MemoryStorage = stores[store_name] + assert store_under_test._data != {} + reloaded: GufeTokenizable = load_func(obj.key) assert reloaded is obj + return reloaded, client @staticmethod - def _test_store_load_different_process(obj: GufeTokenizable, store_func_name, load_func_name): - store = MemoryStorage() - stores = WarehouseStores(setup=store) + def _test_store_load_different_process( + obj: GufeTokenizable, + store_func_name, + load_func_name, + store_name: Literal["setup", "result"], + ): + setup_store = MemoryStorage() + result_store = MemoryStorage() + stores = WarehouseStores(setup=setup_store, result=result_store) client = WarehouseBaseClass(stores) store_func = getattr(client, store_func_name) load_func = getattr(client, load_func_name) - assert store._data == {} + assert setup_store._data == {} + assert result_store._data == {} store_func(obj) - assert store._data != {} + store_under_test: MemoryStorage = stores[store_name] + assert store_under_test._data != {} # make it look like we have an empty cache, as if this was a # different process key = obj.key @@ -54,60 +69,56 @@ def _test_store_load_different_process(obj: GufeTokenizable, store_func_name, lo "fixture", ["absolute_transformation", "complex_equilibrium"], ) - def test_store_load_transformation_same_process(self, request, fixture): + @pytest.mark.parametrize("store", ["setup", "result"]) + def test_store_load_transformation_same_process(self, request, fixture, store): transformation = request.getfixturevalue(fixture) - self._test_store_load_same_process( - transformation, - "store_setup_tokenizable", - "load_setup_tokenizable", - ) + store_func_name = f"store_{store}_tokenizable" + load_func_name = f"load_{store}_tokenizable" + self._test_store_load_same_process(transformation, store_func_name, load_func_name, store) @pytest.mark.parametrize( "fixture", ["absolute_transformation", "complex_equilibrium"], ) - def test_store_load_transformation_different_process(self, request, fixture): + @pytest.mark.parametrize("store", ["setup", "result"]) + def test_store_load_transformation_different_process(self, request, fixture, store): transformation = request.getfixturevalue(fixture) + store_func_name = f"store_{store}_tokenizable" + load_func_name = f"load_{store}_tokenizable" self._test_store_load_different_process( - transformation, - "store_setup_tokenizable", - "load_setup_tokenizable", + transformation, store_func_name, load_func_name, store ) # @pytest.mark.parametrize("fixture", ["benzene_variants_star_map"]) - def test_store_load_network_same_process(self, request, fixture): + @pytest.mark.parametrize("store", ["setup", "result"]) + def test_store_load_network_same_process(self, request, fixture, store): network = request.getfixturevalue(fixture) assert isinstance(network, GufeTokenizable) - self._test_store_load_same_process( - network, "store_setup_tokenizable", "load_setup_tokenizable" - ) + store_func_name = f"store_{store}_tokenizable" + load_func_name = f"load_{store}_tokenizable" + self._test_store_load_same_process(network, store_func_name, load_func_name, store) - # @pytest.mark.parametrize("fixture", ["benzene_variants_star_map"]) - def test_store_load_network_different_process(self, request, fixture): + @pytest.mark.parametrize("store", ["setup", "result"]) + def test_store_load_network_different_process(self, request, fixture, store): network = request.getfixturevalue(fixture) - self._test_store_load_different_process( - network, "store_setup_tokenizable", "load_setup_tokenizable" - ) + assert isinstance(network, GufeTokenizable) + store_func_name = f"store_{store}_tokenizable" + load_func_name = f"load_{store}_tokenizable" + self._test_store_load_different_process(network, store_func_name, load_func_name, store) - # @pytest.mark.parametrize("fixture", ["benzene_variants_star_map"]) - def test_delete(self, request, fixture): - store = MemoryStorage() - stores = WarehouseStores(setup=store) - client = WarehouseBaseClass(stores) - + @pytest.mark.parametrize("store", ["setup", "result"]) + def test_delete(self, request, fixture, store): network = request.getfixturevalue(fixture) - assert store._data == {} - client.store_setup_tokenizable(network) - assert store._data != {} - key = network.key - loaded = client.load_setup_tokenizable(key) - assert loaded is network - assert client.setup_store.exists(key) - client.delete("setup", key) - assert not client.exists(key) + store_func_name = f"store_{store}_tokenizable" + load_func_name = f"load_{store}_tokenizable" + obj, client = self._test_store_load_same_process( + network, store_func_name, load_func_name, store + ) + client.delete(store, obj.key) + assert not client.exists(obj.key) class TestFileSystemWarehouse: