Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
164 changes: 164 additions & 0 deletions cadet/h5.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from typing import Optional, Any
import warnings

import numpy as np
from addict import Dict
with warnings.catch_warnings():
warnings.filterwarnings("ignore", category=FutureWarning)
Expand Down Expand Up @@ -184,6 +185,63 @@ def save(self, lock: bool = False) -> None:
else:
raise ValueError("Filename must be set before save can be used")

def save_as_python_script(
        self,
        filename: str,
        only_return_pythonic_representation: bool = False
        ) -> None | list[str]:
    """
    Save the current state as a Python script.

    The generated script imports numpy and this class, creates a fresh
    instance named ``model``, assigns every entry found under ``self.root``
    to ``model.root`` and finally saves ``model`` to an ".h5" file whose name
    is derived from ``filename``.

    Parameters
    ----------
    filename : str
        The name of the file to save the script to. Must end with ".py".
    only_return_pythonic_representation : bool, optional
        If True, returns the Python code as a list of strings instead of writing
        to a file. Defaults to False.

    Returns
    -------
    None | list[str]
        If `only_return_pythonic_representation` is True, returns a list of strings
        representing the Python code. Otherwise, returns None.

    Raises
    ------
    Warning
        If the filename does not end with ".py".

    """
    if not filename.endswith(".py"):
        raise Warning(
            "Unexpected filename extension. Consider setting a '.py' file."
        )

    class_name = self.__class__.__name__
    code_lines_list = [
        "import numpy as np",
        f"from cadet import {class_name}",
        "",
        f"model = {class_name}()",
    ]

    code_lines_list = recursively_turn_dict_to_python_list(
        dictionary=self.root,
        current_lines_list=code_lines_list,
        prefix="model.root"
    )

    # Swap only the trailing extension: str.replace would also rewrite any
    # ".py" that happens to occur earlier in the path ("my.pyproj/run.py").
    filename_for_reproduced_h5_file = filename[:-len(".py")] + ".h5"
    # !r emits a properly escaped string literal, so backslashes (Windows
    # paths) or quotes in the name cannot break the generated script.
    code_lines_list.append(f"model.filename = {filename_for_reproduced_h5_file!r}")
    code_lines_list.append("model.save()")

    if only_return_pythonic_representation:
        return code_lines_list

    # Python reads source files as UTF-8 by default, so write the script in
    # that encoding instead of the platform's locale encoding.
    with open(filename, "w", encoding="utf-8") as handle:
        handle.writelines(line + "\n" for line in code_lines_list)
    return None

def delete_file(self) -> None:
"""Delete the file associated with the current instance."""
if self.filename is not None:
Expand Down Expand Up @@ -520,3 +578,109 @@ def recursively_save(h5file: h5py.File, path: str, dic: Dict, func: callable) ->
)
else:
raise


def recursively_turn_dict_to_python_list(
        dictionary: dict,
        current_lines_list: list = None,
        prefix: str = None
        ) -> list[str]:
    """
    Recursively convert a nested dictionary (including addict.Dict) into a list of Python code lines
    that can regenerate the original nested structure.

    Keys are visited in sorted order. Every non-dict value becomes one
    assignment line of the form ``<path> = <value>``; nested dictionaries only
    contribute through their leaves, so an empty nested dictionary produces no
    line at all.

    Parameters
    ----------
    dictionary : dict
        The nested dictionary or addict.Dict to convert.
    current_lines_list : list, optional
        A list that accumulates the Python code lines as the recursion progresses.
        If None, a new list is created. The list is extended in place.
    prefix : str, optional
        A prefix used to build fully-qualified variable names representing nested keys.
        It is used verbatim; only the keys appended below it are escaped.

    Returns
    -------
    list of str
        List of Python code lines that, when executed, recreate the nested dictionary.

    Raises
    ------
    ValueError
        If an array with more than 1e7 elements is encountered.
    """
    # Local import: only this function needs the keyword table.
    import keyword

    # Arrays above this size are refused. The same number is handed to numpy
    # as print threshold so that accepted arrays are never abbreviated ("...").
    max_array_size = int(1e7)

    def merge_to_absolute_key(prefix, key):
        """
        Append ``key`` to ``prefix`` as attribute access or, if needed, as item access.

        Item access (``prefix['key']``) is used whenever ``prefix.key`` would
        not be a usable assignment target: the key is not a string, not a
        valid identifier, a reserved keyword such as ``return``, or the name
        of an attribute of the mapping's own class (e.g. ``items``).

        Parameters
        ----------
        prefix : str or None
            The existing path prefix.
        key : str
            The current key to append.

        Returns
        -------
        str
            The combined path if prefix is not None; otherwise, the key itself.
        """
        if prefix is None:
            return key

        usable_as_attribute = (
            isinstance(key, str)
            and key.isidentifier()
            and not keyword.iskeyword(key)
            and not hasattr(type(dictionary), key)
        )
        if usable_as_attribute:
            return f"{prefix}.{key}"
        return f"{prefix}[{key!r}]"

    def get_pythonic_representation_of_value(value):
        """
        Convert a value to a Python code representation.

        Parameters
        ----------
        value : any
            The value to be represented.

        Returns
        -------
        str
            ``np.array(...)`` source text for NumPy arrays, ``repr(value)`` otherwise.
        """
        if not isinstance(value, np.ndarray):
            return repr(value)

        # .size counts all elements and, unlike len(), also works for 0-d arrays.
        if value.size > max_array_size:
            raise ValueError("Array is too long to be serialized")

        # floatmode='unique' prints as many digits as are needed to identify
        # each float, instead of cutting off at the default print precision.
        array_text = np.array2string(
            value,
            separator=',',
            threshold=max_array_size,
            floatmode='unique',
        )
        return f"np.array({array_text})"

    if current_lines_list is None:
        current_lines_list = []

    for key in sorted(dictionary.keys()):
        value = dictionary[key]
        absolute_key = merge_to_absolute_key(prefix, key)

        # isinstance also covers addict.Dict and other dict subclasses.
        if isinstance(value, dict):
            current_lines_list = recursively_turn_dict_to_python_list(
                value,
                current_lines_list,
                prefix=absolute_key
            )
        else:
            value_representation = get_pythonic_representation_of_value(value)
            current_lines_list.append(f"{absolute_key} = {value_representation}")

    return current_lines_list
116 changes: 116 additions & 0 deletions tests/test_save_as_python.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,116 @@
import tempfile

import numpy as np
import pytest
from addict import Dict

from cadet import Cadet


@pytest.fixture
def original_model(tmp_path):
    """
    Create a new, simulated Cadet object for use in tests.

    The model file is placed inside pytest's per-test ``tmp_path`` directory.
    Building the path as ``NamedTemporaryFile().name + ".h5"`` points at a
    different file than the one the context manager removes, so anything
    written there would be left behind in the system temp directory.
    """
    model = Cadet().create_lwe(file_path=str(tmp_path / "model.h5"))
    model.run_simulation()
    return model


def test_save_as_python(original_model):
    """
    Test saving and regenerating a Cadet model as Python code.

    Verifies that a Cadet model can be serialized to a Python script and
    accurately reconstructed by executing the generated script. This ensures
    that model parameters, including arrays and edge-case values, are preserved.

    Parameters
    ----------
    original_model : Cadet
        A Cadet model instance to populate and serialize for testing.

    Raises
    ------
    AssertionError
        If the regenerated model does not match the original model within
        a specified relative tolerance.
    """
    # Populate original_model with all tricky cases currently known
    original_model.root.input.foo = 1
    original_model.root.input.food = 1.9
    original_model.root.input.bar.baryon = np.arange(10)
    original_model.root.input.bar.barometer = np.linspace(0, 10, 9)
    original_model.root.input.bar.init_q = np.array([], dtype=np.float64)
    original_model.root.input.bar.init_qt = np.array([0., 0.0011666666666666668, 0.0023333333333333335])
    original_model.root.input.bar.par_disc_type = np.array([b'EQUIDISTANT_PAR'], dtype='|S15')
    original_model.root.input["return"].split_foobar = 1

    code_lines = original_model.save_as_python_script(
        filename="temp.py", only_return_pythonic_representation=True
    )

    # drop the two trailing lines that set the filename and save the file
    code_lines = code_lines[:-2]

    # Execute the generated script as a whole in its own namespace. The script
    # brings its own imports and binds the rebuilt model to the name "model";
    # an explicit namespace avoids depending on how exec() interacts with the
    # local variables of the enclosing function.
    namespace = {}
    exec("\n".join(code_lines), namespace)
    model = namespace["model"]

    # test that the regenerated "model" is equal to "original_model"
    recursive_equality_check(original_model.root, model.root, rtol=1e-5)


def recursive_equality_check(dict_a: dict, dict_b: dict, rtol=1e-5):
    """
    Recursively compare two nested dictionaries for equality.

    Compares the keys and values of two dictionaries. If a value is a nested
    dictionary, the function recurses with the same tolerance. Numeric NumPy
    arrays are compared using `np.testing.assert_allclose`; arrays of byte or
    unicode strings are compared element-wise for exact equality. All other
    values are compared with ``==``.

    Parameters
    ----------
    dict_a : dict
        First dictionary to compare.
    dict_b : dict
        Second dictionary to compare.
    rtol : float, optional
        Relative tolerance for comparing NumPy arrays, by default 1e-5.

    Returns
    -------
    bool
        True if the dictionaries are equal; otherwise, an assertion is raised.

    Raises
    ------
    AssertionError
        If keys do not match, or values are not equal within the given tolerance.
    """
    assert dict_a.keys() == dict_b.keys()
    for key, value_a in dict_a.items():
        value_b = dict_b[key]
        if isinstance(value_a, dict):
            # Forward rtol so nested levels use the caller's tolerance
            # rather than silently falling back to the default.
            recursive_equality_check(value_a, value_b, rtol=rtol)
        elif isinstance(value_a, np.ndarray):
            # String arrays (e.g. dtype S15 when a simulation file is read
            # back from an H5 file) cannot go through assert_allclose, which
            # needs numeric data; compare them exactly instead.
            if value_a.dtype.kind in ("S", "U"):
                assert np.array_equal(value_a, value_b)
            else:
                np.testing.assert_allclose(value_a, value_b, rtol=rtol)
        else:
            assert value_a == value_b
    return True


# Allow running this test module directly as a script by handing this file to pytest.
if __name__ == "__main__":
pytest.main([__file__])
Loading