diff --git a/bindings/rascal/utils/__init__.py b/bindings/rascal/utils/__init__.py index 824021e8e..e98b98ef5 100644 --- a/bindings/rascal/utils/__init__.py +++ b/bindings/rascal/utils/__init__.py @@ -6,6 +6,7 @@ get_supported_io_versions, dump_obj, load_obj, + json_dumps_frame, ) # Warning potential dependency loop: FPS imports models, which imports KRR, diff --git a/bindings/rascal/utils/io.py b/bindings/rascal/utils/io.py index cbcbe9b9b..37d971567 100644 --- a/bindings/rascal/utils/io.py +++ b/bindings/rascal/utils/io.py @@ -3,6 +3,7 @@ from collections import Iterable import numpy as np import json +import datetime from copy import deepcopy from abc import ABC, abstractmethod @@ -422,3 +423,58 @@ def _load_npy(data, path): if len(v) == 2: if "npy" == v[0]: data[k] = np.array(v[1]) + + +class RascalEncoder(json.JSONEncoder): + def default(self, obj): + if hasattr(obj, "todict"): + d = obj.todict() + + if not isinstance(d, dict): + raise RuntimeError( + "todict() of {} returned object of type {} " + "but should have returned dict".format(obj, type(d)) + ) + if hasattr(obj, "ase_objtype"): + d["__ase_objtype__"] = obj.ase_objtype + + return d + if isinstance(obj, np.ndarray): + return obj.tolist() + if isinstance(obj, np.integer): + return int(obj) + if isinstance(obj, np.bool_): + return bool(obj) + if isinstance(obj, datetime.datetime): + return {"__datetime__": obj.isoformat()} + if isinstance(obj, complex): + return {"__complex__": (obj.real, obj.imag)} + return json.JSONEncoder.default(self, obj) + + +def json_dumps_frame(frames, **json_dumps_kwargs): + """Serialize frames to a JSON formatted string. + + Parameters + ---------- + frames : list(ase.Atoms) or ase.Atoms + List of atomic structures (or single one) to be dumped to a json + + json_dumps_kwargs : dict + List of arguments forwarded to json.dumps + + Return + ------ + T + """ + if type(frames) is not list: + frames = [frames] + + json_frames = {} + for i, frame in enumerate(frames): + json_frames[str(i)] = json.loads(json.dumps(frame, cls=RascalEncoder)) + + json_frames["ids"] = list(range(len(frames))) + json_frames["nextid"] = len(frames) + + return json.dumps(json_frames, **json_dumps_kwargs) diff --git a/tests/python/python_binding_tests.py b/tests/python/python_binding_tests.py index ac6281e6a..f0d48ac43 100755 --- a/tests/python/python_binding_tests.py +++ b/tests/python/python_binding_tests.py @@ -17,7 +17,8 @@ from python_models_test import TestNumericalKernelGradient, TestCosineKernel from python_math_test import TestMath from python_test_sparsify_fps import TestFPS -from python_utils_test import TestOptimalRadialBasis +from python_utils_test import TestOptimalRadialBasis, TestIO + from md_calculator_test import TestGenericMD diff --git a/tests/python/python_utils_test.py b/tests/python/python_utils_test.py index 41e8c81d2..c23e27a4e 100644 --- a/tests/python/python_utils_test.py +++ b/tests/python/python_utils_test.py @@ -7,17 +7,23 @@ get_radial_basis_pca, get_radial_basis_projections, get_optimal_radial_basis_hypers, + json_dumps_frame, ) +from rascal.lib import neighbour_list +from rascal.neighbourlist import base from test_utils import load_json_frame, BoxList, Box, dot +import tempfile import unittest import numpy as np import sys import os import json +import tempfile from copy import copy, deepcopy from scipy.stats import ortho_group import pickle +import ase.io rascal_reference_path = "reference_data" inputs_path = os.path.join(rascal_reference_path, "inputs") @@ -91,3 +97,37 @@ def test_hypers_construction(self): soap_feats_2 = soap_opt_2.transform(self.frames).get_features(soap_opt_2) self.assertTrue(np.allclose(soap_feats, soap_feats_2)) + + +class TestIO(unittest.TestCase): + def setUp(self): + self.fns = [ + os.path.join(inputs_path, "CaCrP2O7_mvc-11955_symmetrized.json"), + os.path.join(inputs_path, "SiC_moissanite_supercell.json"), + os.path.join(inputs_path, "methane.json"), + ] + + def test_json_dumps_frame(self): + """ + Checks if json file decoded by RascalEncoder in dumps_frame can be read + by rascal + """ + nl_options = [ + dict(name="centers", args=dict()), + dict(name="neighbourlist", args=dict(cutoff=3)), + dict(name="centercontribution", args=dict()), + dict(name="strict", args=dict(cutoff=3)), + ] + managers = base.StructureCollectionFactory(nl_options) + for fn in self.fns: + frame = ase.io.read(fn) + dumped_json = json_dumps_frame(frame) + tmp = tempfile.NamedTemporaryFile("w", suffix=".json", delete=False) + tmp.write(dumped_json) + try: + managers.add_structures(tmp.name) + tmp.close() + os.unlink(tmp.name) + except: + tmp.close() + os.unlink(tmp.name)