diff --git a/cadet/h5.py b/cadet/h5.py
index c3a65a2..f73b1b1 100644
--- a/cadet/h5.py
+++ b/cadet/h5.py
@@ -6,6 +6,8 @@
 from typing import Optional, Any
+import keyword
 import warnings
 
+import numpy as np
 from addict import Dict
 with warnings.catch_warnings():
     warnings.filterwarnings("ignore", category=FutureWarning)
@@ -184,6 +186,76 @@ def save(self, lock: bool = False) -> None:
         else:
             raise ValueError("Filename must be set before save can be used")
 
+    def save_as_python_script(
+        self,
+        filename: str,
+        only_return_pythonic_representation: bool = False
+    ) -> None | list[str]:
+        """
+        Save the current state as a Python script.
+
+        The generated script rebuilds ``self.root`` entry by entry on a fresh
+        instance and then saves it to an H5 file named after the script.
+
+        Parameters
+        ----------
+        filename : str
+            The name of the file to save the script to. Should end with ".py".
+            The generated script saves to the same name with ".h5" instead.
+        only_return_pythonic_representation : bool, optional
+            If True, returns the Python code as a list of strings instead of writing
+            to a file. Defaults to False.
+
+        Returns
+        -------
+        None | list[str]
+            If `only_return_pythonic_representation` is True, returns a list of strings
+            representing the Python code. Otherwise, returns None.
+
+        Warns
+        -----
+        UserWarning
+            If the filename does not end with ".py".
+
+        """
+        if not filename.endswith(".py"):
+            # An unusual extension is not an error: warn and carry on instead of
+            # raising the Warning class as an exception.
+            warnings.warn(
+                "Unexpected filename extension. Consider setting a '.py' file.",
+                stacklevel=2,
+            )
+
+        class_name = self.__class__.__name__
+        code_lines_list = [
+            "import numpy as np",
+            # array2string and repr() spell non-finite floats as bare names.
+            "from numpy import inf, nan",
+            f"from cadet import {class_name}",
+            "",
+            f"model = {class_name}()",
+        ]
+
+        code_lines_list = recursively_turn_dict_to_python_list(
+            dictionary=self.root,
+            current_lines_list=code_lines_list,
+            prefix="model.root"
+        )
+
+        # Swap only the trailing extension; str.replace would also rewrite a ".py"
+        # that occurs earlier in the path (e.g. "my.pyth/run.py").
+        filename_for_reproduced_h5_file = filename.removesuffix(".py") + ".h5"
+        # !r quotes and escapes the path, so backslashes and quotes survive.
+        code_lines_list.append(f"model.filename = {filename_for_reproduced_h5_file!r}")
+        code_lines_list.append("model.save()")
+
+        if only_return_pythonic_representation:
+            return code_lines_list
+
+        with open(filename, "w", encoding="utf-8") as handle:
+            handle.writelines(line + "\n" for line in code_lines_list)
+        return None
+
     def delete_file(self) -> None:
         """Delete the file associated with the current instance."""
         if self.filename is not None:
@@ -520,3 +592,132 @@ def recursively_save(h5file: h5py.File, path: str, dic: Dict, func: callable) ->
                     )
                 else:
                     raise
+
+
+# Arrays with more elements than this are refused rather than written out as text.
+_MAX_SERIALIZABLE_ARRAY_SIZE = int(1e7)
+
+
+def _merge_to_absolute_key(prefix: Optional[str], key: Any) -> str:
+    """
+    Append `key` to `prefix` as valid Python source.
+
+    Plain names are appended with dot access. Keys that cannot be used as an
+    attribute (Python keywords such as "return", non-identifiers, non-strings,
+    or names that resolve to an attribute of `Dict` such as "keys") are
+    appended with item access instead.
+
+    Parameters
+    ----------
+    prefix : str or None
+        The existing path prefix.
+    key : Any
+        The current key to append.
+
+    Returns
+    -------
+    str
+        The combined path if prefix is not None; otherwise, the key itself.
+    """
+    if prefix is None:
+        return str(key)
+
+    usable_as_attribute = (
+        isinstance(key, str)
+        and key.isidentifier()
+        and not keyword.iskeyword(key)
+        and not hasattr(Dict, key)
+    )
+    if usable_as_attribute:
+        return f"{prefix}.{key}"
+    return f"{prefix}[{key!r}]"
+
+
+def _get_pythonic_representation_of_value(value: Any) -> str:
+    """
+    Convert a value to Python source code that recreates it.
+
+    Parameters
+    ----------
+    value : Any
+        The value to be represented.
+
+    Returns
+    -------
+    str
+        ``repr(value)`` for non-array values. NumPy arrays are written as an
+        ``np.array(...)`` call.
+
+    Raises
+    ------
+    ValueError
+        If an array has more than `_MAX_SERIALIZABLE_ARRAY_SIZE` elements.
+    """
+    if not isinstance(value, np.ndarray):
+        return repr(value)
+
+    # Use size, not len(): len() fails on 0-d arrays and only counts the first
+    # axis of multi-dimensional ones.
+    if value.size > _MAX_SERIALIZABLE_ARRAY_SIZE:
+        raise ValueError("Array is too long to be serialized")
+
+    value_representation = np.array2string(
+        value,
+        separator=',',
+        threshold=_MAX_SERIALIZABLE_ARRAY_SIZE,
+        # The default print precision (usually 8 digits) would round floats.
+        floatmode="unique",
+    )
+    if value.dtype.names is None:
+        # Keep the dtype, which the printed elements alone do not pin down
+        # (e.g. empty, float32 or fixed-width byte-string arrays).
+        return f"np.array({value_representation}, dtype={str(value.dtype)!r})"
+    return f"np.array({value_representation})"
+
+
+def recursively_turn_dict_to_python_list(
+    dictionary: dict,
+    current_lines_list: Optional[list] = None,
+    prefix: Optional[str] = None
+) -> list:
+    """
+    Recursively convert a nested dictionary (including addict.Dict) into a list of Python code lines
+    that can regenerate the original nested structure.
+
+    Keys are visited in sorted order. Nested dictionaries are recursed into and
+    every other value becomes one assignment of the form ``<path> = <value>``.
+
+    Parameters
+    ----------
+    dictionary : dict
+        The nested dictionary or addict.Dict to convert.
+    current_lines_list : list, optional
+        A list that accumulates the Python code lines as the recursion progresses.
+        It is extended in place. If None, a new list is created.
+    prefix : str, optional
+        A prefix used to build fully-qualified variable names representing nested keys.
+
+    Returns
+    -------
+    list of str
+        List of Python code lines that, when executed, recreate the nested dictionary.
+    """
+    if current_lines_list is None:
+        current_lines_list = []
+
+    for key in sorted(dictionary):
+        value = dictionary[key]
+        absolute_key = _merge_to_absolute_key(prefix, key)
+
+        # isinstance also accepts dict subclasses other than addict.Dict.
+        if isinstance(value, dict):
+            recursively_turn_dict_to_python_list(
+                value,
+                current_lines_list,
+                prefix=absolute_key
+            )
+        else:
+            value_representation = _get_pythonic_representation_of_value(value)
+            current_lines_list.append(f"{absolute_key} = {value_representation}")
+
+    return current_lines_list
diff --git a/tests/test_save_as_python.py b/tests/test_save_as_python.py
new file mode 100644
index 0000000..87d8cc7
--- /dev/null
+++ b/tests/test_save_as_python.py
@@ -0,0 +1,134 @@
+import numpy as np
+import pytest
+
+from cadet import Cadet
+
+
+@pytest.fixture
+def original_model(tmp_path):
+    """
+    Create a new Cadet object for use in tests.
+
+    The H5 file is created inside pytest's `tmp_path` directory.
+    """
+    model = Cadet().create_lwe(file_path=str(tmp_path / "lwe.h5"))
+    model.run_simulation()
+    return model
+
+
+def test_save_as_python(original_model):
+    """
+    Test saving and regenerating a Cadet model as Python code.
+
+    Verifies that a Cadet model can be serialized to a Python script and
+    accurately reconstructed by executing the generated script. This ensures
+    that model parameters, including arrays and edge-case values, are preserved.
+
+    Parameters
+    ----------
+    original_model : Cadet
+        A Cadet model instance to populate and serialize for testing.
+
+    Raises
+    ------
+    AssertionError
+        If the regenerated model does not match the original model within
+        a specified relative tolerance.
+    """
+    # Populate original_model with all tricky cases currently known
+    original_model.root.input.foo = 1
+    original_model.root.input.food = 1.9
+    original_model.root.input.bar.baryon = np.arange(10)
+    original_model.root.input.bar.barometer = np.linspace(0, 10, 9)
+    original_model.root.input.bar.init_q = np.array([], dtype=np.float64)
+    original_model.root.input.bar.init_qt = np.array([0., 0.0011666666666666668, 0.0023333333333333335])
+    original_model.root.input.bar.par_disc_type = np.array([b'EQUIDISTANT_PAR'], dtype='|S15')
+    original_model.root.input.bar.scalar_array = np.array(3.5)
+    original_model.root.input.bar.matrix = np.arange(6, dtype=np.float32).reshape(2, 3)
+    original_model.root.input.bar.non_finite = np.array([np.nan, np.inf, -np.inf])
+    original_model.root.input.bar["items"] = 3
+    original_model.root.input["return"].split_foobar = 1
+    original_model.root.input.return_data = 1
+
+    code_lines = original_model.save_as_python_script(
+        filename="temp.py", only_return_pythonic_representation=True
+    )
+
+    # remove code lines that save the file
+    code_lines = code_lines[:-2]
+
+    # Run the generated code in its own namespace and fetch the "model"
+    # variable that the generated code defines.
+    namespace = {}
+    exec("\n".join(code_lines), namespace)
+    model = namespace["model"]
+
+    # test that "model" is equal to "original_model"
+    recursive_equality_check(original_model.root, model.root, rtol=1e-12)
+
+
+def test_h5_filename_only_replaces_suffix():
+    """Only the trailing ".py" of the script name is swapped for ".h5"."""
+    code_lines = Cadet().save_as_python_script(
+        filename="my.pyth/run.py", only_return_pythonic_representation=True
+    )
+    assert code_lines[-2:] == ["model.filename = 'my.pyth/run.h5'", "model.save()"]
+
+
+def test_unexpected_extension_warns():
+    """A filename that does not end with ".py" triggers a warning."""
+    with pytest.warns(UserWarning, match="Unexpected filename extension"):
+        Cadet().save_as_python_script(
+            filename="run.txt", only_return_pythonic_representation=True
+        )
+
+
+def recursive_equality_check(dict_a: dict, dict_b: dict, rtol=1e-5):
+    """
+    Recursively compare two nested dictionaries for equality.
+
+    Compares the keys and values of two dictionaries. If a value is a nested
+    dictionary, the function recurses. NumPy arrays must have matching dtypes;
+    numeric arrays are compared using `np.testing.assert_allclose`, while string,
+    bytes and object arrays are compared with `np.testing.assert_array_equal`.
+
+    Parameters
+    ----------
+    dict_a : dict
+        First dictionary to compare.
+    dict_b : dict
+        Second dictionary to compare.
+    rtol : float, optional
+        Relative tolerance for comparing NumPy arrays, by default 1e-5.
+
+    Returns
+    -------
+    bool
+        True if the dictionaries are equal; otherwise, an assertion is raised.
+
+    Raises
+    ------
+    AssertionError
+        If keys do not match, or values are not equal within the given tolerance.
+    """
+    assert dict_a.keys() == dict_b.keys()
+    for key, value_a in dict_a.items():
+        value_b = dict_b[key]
+        if isinstance(value_a, dict):
+            # Hand the tolerance down so nested levels do not fall back to the default.
+            recursive_equality_check(value_a, value_b, rtol=rtol)
+        elif isinstance(value_a, np.ndarray):
+            assert isinstance(value_b, np.ndarray)
+            assert value_a.dtype == value_b.dtype
+            if value_a.dtype.kind in "SUO":
+                # A relative tolerance is only meaningful for numeric data.
+                np.testing.assert_array_equal(value_a, value_b)
+            else:
+                np.testing.assert_allclose(value_a, value_b, rtol=rtol)
+        else:
+            assert value_a == value_b
+    return True
+
+
+if __name__ == "__main__":
+    pytest.main([__file__])