Commit 171ad73

[Feature] Add option to overwrite cache_key (#676)
* [Feature] Add option to overwrite cache_key
* [pre-commit.ci] auto fixes from pre-commit.com hooks; for more information, see https://pre-commit.ci
* fixes
* [pre-commit.ci] auto fixes from pre-commit.com hooks; for more information, see https://pre-commit.ci
* recursively iterate over folders
* [pre-commit.ci] auto fixes from pre-commit.com hooks; for more information, see https://pre-commit.ci
* fixes
* [pre-commit.ci] auto fixes from pre-commit.com hooks; for more information, see https://pre-commit.ci
* [Bug] Fix folder generation in dump() function
* fix

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent f259b0a commit 171ad73

7 files changed (+84 additions, -18 deletions)

executorlib/standalone/cache.py

Lines changed: 26 additions & 11 deletions
@@ -13,6 +13,23 @@
 }


+def get_cache_files(cache_directory: str) -> list[str]:
+    """
+    Recursively find all HDF5 files in the cache_directory which contain outputs.
+
+    Args:
+        cache_directory (str): The directory to store cache files.
+
+    Returns:
+        list[str]: List of HDF5 files in the cache directory which contain outputs.
+    """
+    file_lst = []
+    cache_directory_abs = os.path.abspath(cache_directory)
+    for dirpath, _, filenames in os.walk(cache_directory_abs):
+        file_lst += [os.path.join(dirpath, f) for f in filenames if f.endswith("_o.h5")]
+    return file_lst
+
+
 def get_cache_data(cache_directory: str) -> list[dict]:
     """
     Collect all HDF5 files in the cache directory

@@ -27,15 +44,13 @@ def get_cache_data(cache_directory: str) -> list[dict]:
     import numpy as np

     file_lst = []
-    for task_key in os.listdir(cache_directory):
-        file_name = os.path.join(cache_directory, task_key)
-        if task_key[-5:] == "_o.h5":
-            with h5py.File(file_name, "r") as hdf:
-                file_content_dict = {
-                    key: cloudpickle.loads(np.void(hdf["/" + key]))
-                    for key in group_dict.values()
-                    if key in hdf
-                }
-                file_content_dict["filename"] = file_name
-                file_lst.append(file_content_dict)
+    for file_name in get_cache_files(cache_directory=cache_directory):
+        with h5py.File(file_name, "r") as hdf:
+            file_content_dict = {
+                key: cloudpickle.loads(np.void(hdf["/" + key]))
+                for key in group_dict.values()
+                if key in hdf
+            }
+            file_content_dict["filename"] = file_name
+            file_lst.append(file_content_dict)
     return file_lst
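The new helper walks the whole cache directory tree instead of taking a flat listing, so output files written under nested cache keys (e.g. "same/j1_o.h5") are found too. A minimal standalone sketch of the same os.walk() scan, using a throwaway temporary directory with made-up file names:

import os
import tempfile

with tempfile.TemporaryDirectory() as cache_directory:
    # fake one nested cache output plus an unrelated file
    os.makedirs(os.path.join(cache_directory, "same"))
    open(os.path.join(cache_directory, "same", "j1_o.h5"), "w").close()  # nested output
    open(os.path.join(cache_directory, "readme.txt"), "w").close()       # ignored
    file_lst = []
    for dirpath, _, filenames in os.walk(os.path.abspath(cache_directory)):
        file_lst += [os.path.join(dirpath, f) for f in filenames if f.endswith("_o.h5")]
    print(file_lst)  # only .../same/j1_o.h5 is reported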

executorlib/standalone/serialize.py

Lines changed: 14 additions & 2 deletions
@@ -33,6 +33,7 @@ def serialize_funct_h5(
     fn_args: Optional[list] = None,
     fn_kwargs: Optional[dict] = None,
     resource_dict: Optional[dict] = None,
+    cache_key: Optional[str] = None,
 ) -> tuple[str, dict]:
     """
     Serialize a function and its arguments and keyword arguments into an HDF5 file.

@@ -51,6 +52,8 @@
                                   executor: None,
                                   hostname_localhost: False,
                               }
+        cache_key (str, optional): By default the cache_key is generated based on the function hash, this can be
+            overwritten by setting the cache_key.

     Returns:
         Tuple[str, dict]: A tuple containing the task key and the serialized data.

@@ -62,8 +65,17 @@
         fn_kwargs = {}
     if resource_dict is None:
         resource_dict = {}
-    binary_all = cloudpickle.dumps({"fn": fn, "args": fn_args, "kwargs": fn_kwargs})
-    task_key = fn.__name__ + _get_hash(binary=binary_all)
+    if cache_key is not None:
+        task_key = cache_key
+    else:
+        binary_all = cloudpickle.dumps(
+            {
+                "fn": fn,
+                "args": fn_args,
+                "kwargs": fn_kwargs,
+            }
+        )
+        task_key = fn.__name__ + _get_hash(binary=binary_all)
     data = {
         "fn": fn,
         "args": fn_args,

executorlib/task_scheduler/file/hdf.py

Lines changed: 4 additions & 1 deletion
@@ -1,3 +1,4 @@
+import os
 from typing import Any, Optional

 import cloudpickle

@@ -16,7 +17,9 @@ def dump(file_name: Optional[str], data_dict: dict) -> None:
         data_dict (dict): dictionary containing the python function to be executed {"fn": ..., "args": (), "kwargs": {}}
     """
     if file_name is not None:
-        with h5py.File(file_name, "a") as fname:
+        file_name_abs = os.path.abspath(file_name)
+        os.makedirs(os.path.dirname(file_name_abs), exist_ok=True)
+        with h5py.File(file_name_abs, "a") as fname:
             for data_key, data_value in data_dict.items():
                 if data_key in group_dict:
                     fname.create_dataset(
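This is the companion fix: with a cache_key such as "a/b/c" the input file becomes <cache_directory>/a/b/c_i.h5, whose parent folders may not exist yet. A small sketch of the new behaviour (the dataset name is chosen for illustration only):

import os

import h5py

# cache_key "a/b/c" puts the input file two folders deep
file_name_abs = os.path.abspath(os.path.join("executorlib_cache", "a", "b", "c_i.h5"))
os.makedirs(os.path.dirname(file_name_abs), exist_ok=True)  # the step added above
with h5py.File(file_name_abs, "a") as fname:
    # without the makedirs() call, opening the file would fail because
    # "executorlib_cache/a/b" does not exist yet
    fname.create_dataset("function", data=b"cloudpickled payload")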

executorlib/task_scheduler/file/shared.py

Lines changed: 7 additions & 2 deletions
@@ -6,6 +6,7 @@
 from concurrent.futures import Future
 from typing import Any, Callable, Optional

+from executorlib.standalone.cache import get_cache_files
 from executorlib.standalone.command import get_command_path
 from executorlib.standalone.serialize import serialize_funct_h5
 from executorlib.task_scheduler.file.hdf import dump, get_output

@@ -72,6 +73,7 @@ def execute_tasks_h5(
         terminate_function (Callable): The function to terminate the tasks.
         pysqa_config_directory (str, optional): path to the pysqa config directory (only for pysqa based backend).
         backend (str, optional): name of the backend used to spawn tasks.
+        disable_dependencies (boolean): Disable resolving future objects during the submission.

     Returns:
         None

@@ -101,15 +103,18 @@
             task_resource_dict.update(
                 {k: v for k, v in resource_dict.items() if k not in task_resource_dict}
             )
+            cache_key = task_resource_dict.pop("cache_key", None)
             task_key, data_dict = serialize_funct_h5(
                 fn=task_dict["fn"],
                 fn_args=task_args,
                 fn_kwargs=task_kwargs,
                 resource_dict=task_resource_dict,
+                cache_key=cache_key,
             )
             if task_key not in memory_dict:
-                if task_key + "_o.h5" not in os.listdir(cache_directory):
-                    os.makedirs(cache_directory, exist_ok=True)
+                if task_key + "_o.h5" not in get_cache_files(
+                    cache_directory=cache_directory
+                ):
                     file_name = os.path.join(cache_directory, task_key + "_i.h5")
                     dump(file_name=file_name, data_dict=data_dict)
                     if not disable_dependencies:
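Note that cache_key rides along inside resource_dict and is popped before the remaining resource settings are forwarded, so it never reaches the spawner as an unknown resource. The same pop in isolation:

# cache_key is a per-task setting smuggled in via resource_dict
task_resource_dict = {"cores": 1, "cache_key": "same/j1"}
cache_key = task_resource_dict.pop("cache_key", None)
print(cache_key)           # "same/j1"
print(task_resource_dict)  # {'cores': 1} -- only genuine resources remain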

executorlib/task_scheduler/interactive/shared.py

Lines changed: 11 additions & 2 deletions
@@ -6,6 +6,7 @@
 import time
 from typing import Callable, Optional

+from executorlib.standalone.cache import get_cache_files
 from executorlib.standalone.command import get_command_path
 from executorlib.standalone.interactive.communication import (
     SocketInterface,

@@ -22,6 +23,7 @@ def execute_tasks(
     hostname_localhost: Optional[bool] = None,
     init_function: Optional[Callable] = None,
     cache_directory: Optional[str] = None,
+    cache_key: Optional[str] = None,
     queue_join_on_shutdown: bool = True,
     log_obj_size: bool = False,
     **kwargs,

@@ -42,6 +44,8 @@
                              option to true
        init_function (Callable): optional function to preset arguments for functions which are submitted later
        cache_directory (str, optional): The directory to store cache files. Defaults to "executorlib_cache".
+       cache_key (str, optional): By default the cache_key is generated based on the function hash, this can be
+           overwritten by setting the cache_key.
        queue_join_on_shutdown (bool): Join communication queue when thread is closed. Defaults to True.
        log_obj_size (bool): Enable debug mode which reports the size of the communicated objects.
    """

@@ -76,6 +80,7 @@
                 task_dict=task_dict,
                 future_queue=future_queue,
                 cache_directory=cache_directory,
+                cache_key=cache_key,
             )


@@ -132,6 +137,7 @@ def _execute_task_with_cache(
     task_dict: dict,
     future_queue: queue.Queue,
     cache_directory: str,
+    cache_key: Optional[str] = None,
 ):
     """
     Execute the task in the task_dict by communicating it via the interface using the cache in the cache directory.

@@ -142,6 +148,8 @@
            {"fn": Callable, "args": (), "kwargs": {}, "resource_dict": {}}
        future_queue (Queue): Queue for receiving new tasks.
        cache_directory (str): The directory to store cache files.
+       cache_key (str, optional): By default the cache_key is generated based on the function hash, this can be
+           overwritten by setting the cache_key.
    """
    from executorlib.task_scheduler.file.hdf import dump, get_output

@@ -150,10 +158,11 @@
         fn_args=task_dict["args"],
         fn_kwargs=task_dict["kwargs"],
         resource_dict=task_dict.get("resource_dict", {}),
+        cache_key=cache_key,
     )
     os.makedirs(cache_directory, exist_ok=True)
-    file_name = os.path.join(cache_directory, task_key + "_o.h5")
-    if task_key + "_o.h5" not in os.listdir(cache_directory):
+    file_name = os.path.abspath(os.path.join(cache_directory, task_key + "_o.h5"))
+    if file_name not in get_cache_files(cache_directory=cache_directory):
         f = task_dict.pop("future")
         if f.set_running_or_notify_cancel():
             try:
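The existence check changes accordingly: instead of looking for the file name in a flat os.listdir() of the top-level cache directory, which can never match a nested key like "same/j1", the absolute path is compared against the recursive listing. Assuming an executorlib checkout containing this commit is importable, the check in isolation:

import os

from executorlib.standalone.cache import get_cache_files

cache_directory = "executorlib_cache"
task_key = "same/j1"  # a user-supplied, nested cache_key
file_name = os.path.abspath(os.path.join(cache_directory, task_key + "_o.h5"))
if file_name not in get_cache_files(cache_directory=cache_directory):
    print("no cached output yet; the task would be executed")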

tests/test_cache_fileexecutor_serial.py

Lines changed: 7 additions & 0 deletions
@@ -42,6 +42,13 @@ def test_executor_mixed(self):
             self.assertEqual(fs1.result(), 3)
             self.assertTrue(fs1.done())

+    def test_executor_mixed_cache_key(self):
+        with FileTaskScheduler(execute_function=execute_in_subprocess) as exe:
+            fs1 = exe.submit(my_funct, 1, b=2, resource_dict={"cache_key": "a/b/c"})
+            self.assertFalse(fs1.done())
+            self.assertEqual(fs1.result(), 3)
+            self.assertTrue(fs1.done())
+
     def test_executor_dependence_mixed(self):
         with FileTaskScheduler(execute_function=execute_in_subprocess) as exe:
             fs1 = exe.submit(my_funct, 1, b=2)

tests/test_singlenodeexecutor_cache.py

Lines changed: 15 additions & 0 deletions
@@ -34,6 +34,21 @@ def test_cache_data(self):
             sum([sum(c["input_args"][0]) for c in cache_lst]), sum(result_lst)
         )

+    def test_cache_key(self):
+        cache_directory = os.path.abspath("executorlib_cache")
+        with SingleNodeExecutor(cache_directory=cache_directory) as exe:
+            self.assertTrue(exe)
+            future_lst = [exe.submit(sum, [i, i], resource_dict={"cache_key": "same/j" + str(i)}) for i in range(1, 4)]
+            result_lst = [f.result() for f in future_lst]
+
+        cache_lst = get_cache_data(cache_directory=cache_directory)
+        for entry in cache_lst:
+            self.assertTrue("same" in entry['filename'])
+        self.assertEqual(sum([c["output"] for c in cache_lst]), sum(result_lst))
+        self.assertEqual(
+            sum([sum(c["input_args"][0]) for c in cache_lst]), sum(result_lst)
+        )
+
     def test_cache_error(self):
         cache_directory = os.path.abspath("cache_error")
         with SingleNodeExecutor(cache_directory=cache_directory) as exe:
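Putting it together, the test exercises the user-facing pattern: pass cache_key inside resource_dict at submission time, and the cached HDF5 files end up under the corresponding subfolder. A trimmed-down version of the same usage:

import os

from executorlib import SingleNodeExecutor

cache_directory = os.path.abspath("executorlib_cache")
with SingleNodeExecutor(cache_directory=cache_directory) as exe:
    fs = exe.submit(sum, [1, 1], resource_dict={"cache_key": "same/j1"})
    print(fs.result())  # 2 -- output cached as executorlib_cache/same/j1_o.h5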
