From b761ac0fc0c899124bdeb6e01bc2f8319089ad4d Mon Sep 17 00:00:00 2001 From: "r.jaepel" Date: Mon, 5 May 2025 14:40:57 +0200 Subject: [PATCH 1/8] Add .save_as_python_script method and test --- cadet/h5.py | 94 ++++++++++++++++++++++++++++++++++++ tests/test_save_as_python.py | 64 ++++++++++++++++++++++++ 2 files changed, 158 insertions(+) create mode 100644 tests/test_save_as_python.py diff --git a/cadet/h5.py b/cadet/h5.py index c3a65a2..c8b38c8 100644 --- a/cadet/h5.py +++ b/cadet/h5.py @@ -184,6 +184,32 @@ def save(self, lock: bool = False) -> None: else: raise ValueError("Filename must be set before save can be used") + def save_as_python_script(self, filename: str, only_return_pythonic_representation=False): + if not filename.endswith(".py"): + raise Warning(f"The filename given to .save_as_python_script isn't a python file name.") + + code_lines_list = [ + "import numpy", + "from cadet import Cadet", + "", + "sim = Cadet()", + "root = sim.root", + ] + + code_lines_list = recursively_turn_dict_to_python_list(dictionary=self.root, + current_lines_list=code_lines_list, + prefix="root") + + filename_for_reproduced_h5_file = filename.replace(".py", ".h5") + code_lines_list.append(f"sim.filename = '{filename_for_reproduced_h5_file}'") + code_lines_list.append("sim.save()") + + if not only_return_pythonic_representation: + with open(filename, "w") as handle: + handle.writelines([line + "\n" for line in code_lines_list]) + else: + return code_lines_list + def delete_file(self) -> None: """Delete the file associated with the current instance.""" if self.filename is not None: @@ -520,3 +546,71 @@ def recursively_save(h5file: h5py.File, path: str, dic: Dict, func: callable) -> ) else: raise + + +def recursively_turn_dict_to_python_list(dictionary: dict, current_lines_list: list = None, prefix: str = None): + """ + Recursively turn a nested dictionary or addict.Dict into a list of Python code that + can generate the nested dictionary. + + :param dictionary: + :param current_lines_list: + :param prefix_list: + :return: list of Python code lines + """ + + def merge_to_absolute_key(prefix, key): + """ + Combine key and prefix to "prefix.key" except if there is no prefix, then return key + """ + if prefix is None: + return key + else: + return f"{prefix}.{key}" + + def clean_up_key(absolute_key: str): + """ + Remove problematic phrases from key, such as blank "return" + + :param absolute_key: + :return: + """ + absolute_key = absolute_key.replace(".return", "['return']") + return absolute_key + + def get_pythonic_representation_of_value(value): + """ + Use repr() to get a pythonic representation of the value + and add "np." to "array" and "float64" + + """ + value_representation = repr(value) + value_representation = value_representation.replace("array", "numpy.array") + value_representation = value_representation.replace("float64", "numpy.float64") + try: + eval(value_representation) + except NameError as e: + raise ValueError( + f"Encountered a value of '{value_representation}' that can't be directly reproduced in python.\n" + f"Please report this to the CADET-Python developers.") from e + + return value_representation + + if current_lines_list is None: + current_lines_list = [] + + for key in sorted(dictionary.keys()): + value = dictionary[key] + + absolute_key = merge_to_absolute_key(prefix, key) + + if type(value) in (dict, Dict): + current_lines_list = recursively_turn_dict_to_python_list(value, current_lines_list, prefix=absolute_key) + else: + value_representation = get_pythonic_representation_of_value(value) + + absolute_key = clean_up_key(absolute_key) + + current_lines_list.append(f"{absolute_key} = {value_representation}") + + return current_lines_list diff --git a/tests/test_save_as_python.py b/tests/test_save_as_python.py new file mode 100644 index 0000000..c59b52f --- /dev/null +++ b/tests/test_save_as_python.py @@ -0,0 +1,64 @@ +import tempfile + +import numpy as np +import pytest +from addict import Dict + +from cadet import Cadet + + +@pytest.fixture +def temp_cadet_file(): + """ + Create a new Cadet object for use in tests. + """ + model = Cadet() + + with tempfile.NamedTemporaryFile() as temp: + model.filename = temp + yield model + + +def test_save_as_python(temp_cadet_file): + """ + Test that the Cadet class raises a KeyError exception when duplicate keys are set on it. + """ + # initialize "sim" variable to be overwritten by the exec lines later + sim = Cadet() + + # Populate temp_cadet_file with all tricky cases currently known + temp_cadet_file.root.input.foo = 1 + temp_cadet_file.root.input.bar.baryon = np.arange(10) + temp_cadet_file.root.input.bar.barometer = np.linspace(0, 10, 9) + temp_cadet_file.root.input.bar.init_q = np.array([], dtype=np.float64) + temp_cadet_file.root.input["return"].split_foobar = 1 + + code_lines = temp_cadet_file.save_as_python_script(filename="temp.py", only_return_pythonic_representation=True) + + # remove code lines that save the file + code_lines = code_lines[:-2] + + # populate "sim" variable using the generated code lines + for line in code_lines: + exec(line) + + # test that "sim" is equal to "temp_cadet_file" + recursive_equality_check(sim.root, temp_cadet_file.root) + + +def recursive_equality_check(dict_a: dict, dict_b: dict): + assert dict_a.keys() == dict_b.keys() + for key in dict_a.keys(): + value_a = dict_a[key] + value_b = dict_b[key] + if type(value_a) in (dict, Dict): + recursive_equality_check(value_a, value_b) + elif type(value_a) == np.ndarray: + np.testing.assert_array_equal(value_a, value_b) + else: + assert value_a == value_b + return True + + +if __name__ == "__main__": + pytest.main() From 9acdad1d8e1ecc651da172ed7625ea1f0ea392e7 Mon Sep 17 00:00:00 2001 From: "r.jaepel" Date: Mon, 5 May 2025 15:01:18 +0200 Subject: [PATCH 2/8] remove eval() --- cadet/h5.py | 7 ------- 1 file changed, 7 deletions(-) diff --git a/cadet/h5.py b/cadet/h5.py index c8b38c8..a4f8769 100644 --- a/cadet/h5.py +++ b/cadet/h5.py @@ -587,13 +587,6 @@ def get_pythonic_representation_of_value(value): value_representation = repr(value) value_representation = value_representation.replace("array", "numpy.array") value_representation = value_representation.replace("float64", "numpy.float64") - try: - eval(value_representation) - except NameError as e: - raise ValueError( - f"Encountered a value of '{value_representation}' that can't be directly reproduced in python.\n" - f"Please report this to the CADET-Python developers.") from e - return value_representation if current_lines_list is None: From 9bc60360b5b6814ac72e87ce564f96b399d2e76d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Johannes=20Schm=C3=B6lder?= Date: Tue, 6 May 2025 15:28:45 +0200 Subject: [PATCH 3/8] fixup! Add .save_as_python_script method and test --- cadet/h5.py | 42 +++++++++++++++++++++++++++++++----- tests/test_save_as_python.py | 4 +++- 2 files changed, 40 insertions(+), 6 deletions(-) diff --git a/cadet/h5.py b/cadet/h5.py index a4f8769..956efe1 100644 --- a/cadet/h5.py +++ b/cadet/h5.py @@ -184,9 +184,38 @@ def save(self, lock: bool = False) -> None: else: raise ValueError("Filename must be set before save can be used") - def save_as_python_script(self, filename: str, only_return_pythonic_representation=False): + def save_as_python_script( + self, + filename: str, + only_return_pythonic_representation: bool = False + ) -> None | list[str]: + """ + Save the current state as a Python script. + + Parameters + ---------- + filename : str + The name of the file to save the script to. Must end with ".py". + only_return_pythonic_representation : bool, optional + If True, returns the Python code as a list of strings instead of writing + to a file. Defaults to False. + + Returns + ------- + None | list[str] + If `only_return_pythonic_representation` is True, returns a list of strings + representing the Python code. Otherwise, returns None. + + Raises + ------ + Warning + If the filename does not end with ".py". + + """ if not filename.endswith(".py"): - raise Warning(f"The filename given to .save_as_python_script isn't a python file name.") + raise Warning( + "Unexpected filename extension. Consider setting a '.py' file." + ) code_lines_list = [ "import numpy", @@ -196,9 +225,11 @@ def save_as_python_script(self, filename: str, only_return_pythonic_representati "root = sim.root", ] - code_lines_list = recursively_turn_dict_to_python_list(dictionary=self.root, - current_lines_list=code_lines_list, - prefix="root") + code_lines_list = recursively_turn_dict_to_python_list( + dictionary=self.root, + current_lines_list=code_lines_list, + prefix="root" + ) filename_for_reproduced_h5_file = filename.replace(".py", ".h5") code_lines_list.append(f"sim.filename = '{filename_for_reproduced_h5_file}'") @@ -207,6 +238,7 @@ def save_as_python_script(self, filename: str, only_return_pythonic_representati if not only_return_pythonic_representation: with open(filename, "w") as handle: handle.writelines([line + "\n" for line in code_lines_list]) + return else: return code_lines_list diff --git a/tests/test_save_as_python.py b/tests/test_save_as_python.py index c59b52f..4335dbf 100644 --- a/tests/test_save_as_python.py +++ b/tests/test_save_as_python.py @@ -33,7 +33,9 @@ def test_save_as_python(temp_cadet_file): temp_cadet_file.root.input.bar.init_q = np.array([], dtype=np.float64) temp_cadet_file.root.input["return"].split_foobar = 1 - code_lines = temp_cadet_file.save_as_python_script(filename="temp.py", only_return_pythonic_representation=True) + code_lines = temp_cadet_file.save_as_python_script( + filename="temp.py", only_return_pythonic_representation=True + ) # remove code lines that save the file code_lines = code_lines[:-2] From 92c9c8d2c88150d6a5f267bfc2216207fa2d478d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Johannes=20Schm=C3=B6lder?= Date: Tue, 6 May 2025 15:40:04 +0200 Subject: [PATCH 4/8] fixup! Add .save_as_python_script method and test --- cadet/h5.py | 17 ++++++++--------- tests/test_save_as_python.py | 6 +++--- 2 files changed, 11 insertions(+), 12 deletions(-) diff --git a/cadet/h5.py b/cadet/h5.py index 956efe1..e10e06c 100644 --- a/cadet/h5.py +++ b/cadet/h5.py @@ -218,22 +218,21 @@ def save_as_python_script( ) code_lines_list = [ - "import numpy", - "from cadet import Cadet", + "import numpy as np", + f"from cadet import {self.__class__.__name__}", "", - "sim = Cadet()", - "root = sim.root", + f"model = {self.__class__.__name__}()", ] code_lines_list = recursively_turn_dict_to_python_list( dictionary=self.root, current_lines_list=code_lines_list, - prefix="root" + prefix="model.root" ) filename_for_reproduced_h5_file = filename.replace(".py", ".h5") - code_lines_list.append(f"sim.filename = '{filename_for_reproduced_h5_file}'") - code_lines_list.append("sim.save()") + code_lines_list.append(f"model.filename = '{filename_for_reproduced_h5_file}'") + code_lines_list.append("model.save()") if not only_return_pythonic_representation: with open(filename, "w") as handle: @@ -617,8 +616,8 @@ def get_pythonic_representation_of_value(value): """ value_representation = repr(value) - value_representation = value_representation.replace("array", "numpy.array") - value_representation = value_representation.replace("float64", "numpy.float64") + value_representation = value_representation.replace("array", "np.array") + value_representation = value_representation.replace("float64", "np.float64") return value_representation if current_lines_list is None: diff --git a/tests/test_save_as_python.py b/tests/test_save_as_python.py index 4335dbf..c0d30d2 100644 --- a/tests/test_save_as_python.py +++ b/tests/test_save_as_python.py @@ -24,7 +24,7 @@ def test_save_as_python(temp_cadet_file): Test that the Cadet class raises a KeyError exception when duplicate keys are set on it. """ # initialize "sim" variable to be overwritten by the exec lines later - sim = Cadet() + model = Cadet() # Populate temp_cadet_file with all tricky cases currently known temp_cadet_file.root.input.foo = 1 @@ -45,7 +45,7 @@ def test_save_as_python(temp_cadet_file): exec(line) # test that "sim" is equal to "temp_cadet_file" - recursive_equality_check(sim.root, temp_cadet_file.root) + recursive_equality_check(model.root, temp_cadet_file.root) def recursive_equality_check(dict_a: dict, dict_b: dict): @@ -63,4 +63,4 @@ def recursive_equality_check(dict_a: dict, dict_b: dict): if __name__ == "__main__": - pytest.main() + pytest.main([__file__]) From 80d5d00fc90e22d6f9702e9f437b43d8227512b5 Mon Sep 17 00:00:00 2001 From: "r.jaepel" Date: Wed, 7 May 2025 10:15:25 +0200 Subject: [PATCH 5/8] fixup! Add .save_as_python_script method and test --- cadet/h5.py | 68 +++++++++++++++++++++++++++++++++++++++++++---------- 1 file changed, 55 insertions(+), 13 deletions(-) diff --git a/cadet/h5.py b/cadet/h5.py index e10e06c..a0ffa0b 100644 --- a/cadet/h5.py +++ b/cadet/h5.py @@ -581,18 +581,40 @@ def recursively_save(h5file: h5py.File, path: str, dic: Dict, func: callable) -> def recursively_turn_dict_to_python_list(dictionary: dict, current_lines_list: list = None, prefix: str = None): """ - Recursively turn a nested dictionary or addict.Dict into a list of Python code that - can generate the nested dictionary. + Recursively convert a nested dictionary (including addict.Dict) into a list of Python code lines + that can regenerate the original nested structure. - :param dictionary: - :param current_lines_list: - :param prefix_list: - :return: list of Python code lines + Parameters + ---------- + dictionary : dict + The nested dictionary or addict.Dict to convert. + current_lines_list : list, optional + A list that accumulates the Python code lines as the recursion progresses. + If None, a new list is created. + prefix : str, optional + A prefix used to build fully-qualified variable names representing nested keys. + + Returns + ------- + list of str + List of Python code lines that, when executed, recreate the nested dictionary. """ def merge_to_absolute_key(prefix, key): """ - Combine key and prefix to "prefix.key" except if there is no prefix, then return key + Combine prefix and key into a dot-separated path unless the prefix is None. + + Parameters + ---------- + prefix : str or None + The existing path prefix. + key : str + The current key to append. + + Returns + ------- + str + Dot-separated key path if prefix is not None; otherwise, the key itself. """ if prefix is None: return key @@ -601,19 +623,35 @@ def merge_to_absolute_key(prefix, key): def clean_up_key(absolute_key: str): """ - Remove problematic phrases from key, such as blank "return" + Sanitize a key path by replacing problematic substrings like '.return'. - :param absolute_key: - :return: + Parameters + ---------- + absolute_key : str + A dot-separated key path. + + Returns + ------- + str + A cleaned key path with special keywords properly escaped. """ absolute_key = absolute_key.replace(".return", "['return']") return absolute_key def get_pythonic_representation_of_value(value): """ - Use repr() to get a pythonic representation of the value - and add "np." to "array" and "float64" + Convert a value to a Python code representation, with NumPy-style modifications. + Parameters + ---------- + value : any + The value to be represented. + + Returns + ------- + str + A string representation using `repr()`, with `array` replaced by `np.array` + and `float64` replaced by `np.float64`. """ value_representation = repr(value) value_representation = value_representation.replace("array", "np.array") @@ -629,7 +667,11 @@ def get_pythonic_representation_of_value(value): absolute_key = merge_to_absolute_key(prefix, key) if type(value) in (dict, Dict): - current_lines_list = recursively_turn_dict_to_python_list(value, current_lines_list, prefix=absolute_key) + current_lines_list = recursively_turn_dict_to_python_list( + value, + current_lines_list, + prefix=absolute_key + ) else: value_representation = get_pythonic_representation_of_value(value) From 50518abc8d97865da8c531bd173a5255128bd555 Mon Sep 17 00:00:00 2001 From: Jan Breuer Date: Wed, 7 May 2025 14:02:52 +0200 Subject: [PATCH 6/8] fixup! Fix float64 representation --- cadet/h5.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/cadet/h5.py b/cadet/h5.py index a0ffa0b..63c4c9a 100644 --- a/cadet/h5.py +++ b/cadet/h5.py @@ -650,12 +650,10 @@ def get_pythonic_representation_of_value(value): Returns ------- str - A string representation using `repr()`, with `array` replaced by `np.array` - and `float64` replaced by `np.float64`. + A string representation using `repr()`, with `array` replaced by `np.array`. """ value_representation = repr(value) value_representation = value_representation.replace("array", "np.array") - value_representation = value_representation.replace("float64", "np.float64") return value_representation if current_lines_list is None: From 1416997a3074ff5cffb074f09754b8a2ccb36b8d Mon Sep 17 00:00:00 2001 From: "r.jaepel" Date: Thu, 8 May 2025 12:35:27 +0200 Subject: [PATCH 7/8] Allow long arrays in save_as_python --- cadet/h5.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/cadet/h5.py b/cadet/h5.py index 63c4c9a..f73b1b1 100644 --- a/cadet/h5.py +++ b/cadet/h5.py @@ -6,6 +6,7 @@ from typing import Optional, Any import warnings +import numpy as np from addict import Dict with warnings.catch_warnings(): warnings.filterwarnings("ignore", category=FutureWarning) @@ -652,8 +653,13 @@ def get_pythonic_representation_of_value(value): str A string representation using `repr()`, with `array` replaced by `np.array`. """ - value_representation = repr(value) - value_representation = value_representation.replace("array", "np.array") + if isinstance(value, np.ndarray): + if len(value) > 1e7: + raise ValueError("Array is too long to be serialized") + value_representation = np.array2string(value, separator=',', threshold=int(1e7)) + value_representation = f"np.array({value_representation})" + else: + value_representation = repr(value) return value_representation if current_lines_list is None: From 9df19fe10dc4177183cd7a321273236273445f91 Mon Sep 17 00:00:00 2001 From: "r.jaepel" Date: Thu, 8 May 2025 12:35:57 +0200 Subject: [PATCH 8/8] Extend test of save_as_python --- tests/test_save_as_python.py | 88 ++++++++++++++++++++++++++++-------- 1 file changed, 69 insertions(+), 19 deletions(-) diff --git a/tests/test_save_as_python.py b/tests/test_save_as_python.py index c0d30d2..87d8cc7 100644 --- a/tests/test_save_as_python.py +++ b/tests/test_save_as_python.py @@ -8,32 +8,50 @@ @pytest.fixture -def temp_cadet_file(): +def original_model(): """ Create a new Cadet object for use in tests. """ - model = Cadet() - with tempfile.NamedTemporaryFile() as temp: - model.filename = temp + model = Cadet().create_lwe(file_path=temp.name+".h5") + model.run_simulation() yield model -def test_save_as_python(temp_cadet_file): +def test_save_as_python(original_model): """ - Test that the Cadet class raises a KeyError exception when duplicate keys are set on it. + Test saving and regenerating a Cadet model as Python code. + + Verifies that a Cadet model can be serialized to a Python script and + accurately reconstructed by executing the generated script. This ensures + that model parameters, including arrays and edge-case values, are preserved. + + Parameters + ---------- + original_model : Cadet + A Cadet model instance to populate and serialize for testing. + + Raises + ------ + AssertionError + If the regenerated model does not match the original model within + a specified relative tolerance. """ - # initialize "sim" variable to be overwritten by the exec lines later + # initialize "model" variable to be overwritten by the exec lines later + # it needs to be called "model", as that is the variable that the generated code overwrites model = Cadet() - # Populate temp_cadet_file with all tricky cases currently known - temp_cadet_file.root.input.foo = 1 - temp_cadet_file.root.input.bar.baryon = np.arange(10) - temp_cadet_file.root.input.bar.barometer = np.linspace(0, 10, 9) - temp_cadet_file.root.input.bar.init_q = np.array([], dtype=np.float64) - temp_cadet_file.root.input["return"].split_foobar = 1 - - code_lines = temp_cadet_file.save_as_python_script( + # Populate original_model with all tricky cases currently known + original_model.root.input.foo = 1 + original_model.root.input.food = 1.9 + original_model.root.input.bar.baryon = np.arange(10) + original_model.root.input.bar.barometer = np.linspace(0, 10, 9) + original_model.root.input.bar.init_q = np.array([], dtype=np.float64) + original_model.root.input.bar.init_qt = np.array([0., 0.0011666666666666668, 0.0023333333333333335]) + original_model.root.input.bar.par_disc_type = np.array([b'EQUIDISTANT_PAR'], dtype='|S15') + original_model.root.input["return"].split_foobar = 1 + + code_lines = original_model.save_as_python_script( filename="temp.py", only_return_pythonic_representation=True ) @@ -45,18 +63,50 @@ def test_save_as_python(temp_cadet_file): exec(line) # test that "sim" is equal to "temp_cadet_file" - recursive_equality_check(model.root, temp_cadet_file.root) + recursive_equality_check(original_model.root, model.root, rtol=1e-5) -def recursive_equality_check(dict_a: dict, dict_b: dict): +def recursive_equality_check(dict_a: dict, dict_b: dict, rtol=1e-5): + """ + Recursively compare two nested dictionaries for equality. + + Compares the keys and values of two dictionaries. If a value is a nested + dictionary, the function recurses. NumPy arrays are compared using + `np.testing.assert_allclose`, except for byte strings which are compared + directly. + + Parameters + ---------- + dict_a : dict + First dictionary to compare. + dict_b : dict + Second dictionary to compare. + rtol : float, optional + Relative tolerance for comparing NumPy arrays, by default 1e-5. + + Returns + ------- + bool + True if the dictionaries are equal; otherwise, an assertion is raised. + + Raises + ------ + AssertionError + If keys do not match, or values are not equal within the given tolerance. + """ assert dict_a.keys() == dict_b.keys() for key in dict_a.keys(): value_a = dict_a[key] value_b = dict_b[key] if type(value_a) in (dict, Dict): recursive_equality_check(value_a, value_b) - elif type(value_a) == np.ndarray: - np.testing.assert_array_equal(value_a, value_b) + elif isinstance(value_a, np.ndarray): + # This catches cases where strings are stored in arrays, and the dtype S15 causes numpy problems + # which can happen if reading a simulation file back from an H5 file from disk + if value_a.dtype == np.dtype("S15") and len(value_a) == 1 and len(value_b) == 1: + assert value_a[0] == value_b[0] + else: + np.testing.assert_allclose(value_a, value_b, rtol=rtol) else: assert value_a == value_b return True