Skip to content
Closed
8 changes: 8 additions & 0 deletions dpdata/deepmd/mixed.py
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please check other files in this directory to see how we support ANY new keys with the plugin system.

if labels:
dtypes = dpdata.system.LabeledSystem.DTYPES
else:
dtypes = dpdata.system.System.DTYPES
for dtype in dtypes:
if dtype.name in (
"atom_numbs",

Copy link
Contributor Author

@anyangml anyangml Jun 5, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please check other files in this directory to see how we support ANY new keys with the plugin system.

@njzjz I am not quite following. Are you suggesting that there are missing changes need to be added, or you want the feature to be implemented using a different approach. This PR only fix the fparam bug in mixed systems.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

fparam is implemented with plugins, so I don't suggest handling it specially. Ideally, we should handle any registered data type.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

fparam is implemented with plugins, so I don't suggest handling it specially. Ideally, we should handle any registered data type.

If I understand correctly, you are suggesting refactoring the temp_idx logic into comp.py, not only for fparam but also include all the other dtypes?

temp_idx = np.arange(all_real_atom_types_concat.shape[0])[
(all_real_atom_types_concat == all_real_atom_types_concat[0]).all(-1)
]
rest_idx = np.arange(all_real_atom_types_concat.shape[0])[
(all_real_atom_types_concat != all_real_atom_types_concat[0]).any(-1)
]
temp_data = data.copy()
temp_data["atom_names"] = data["atom_names"].copy()
temp_data["atom_numbs"] = temp_atom_numbs
temp_data["atom_types"] = all_real_atom_types_concat[0]
all_real_atom_types_concat = all_real_atom_types_concat[rest_idx]
temp_data["cells"] = all_cells_concat[temp_idx]
all_cells_concat = all_cells_concat[rest_idx]
temp_data["coords"] = all_coords_concat[temp_idx]
all_coords_concat = all_coords_concat[rest_idx]
if labels:
if all_eners_concat is not None and all_eners_concat.size > 0:
temp_data["energies"] = all_eners_concat[temp_idx]
all_eners_concat = all_eners_concat[rest_idx]
if all_forces_concat is not None and all_forces_concat.size > 0:
temp_data["forces"] = all_forces_concat[temp_idx]
all_forces_concat = all_forces_concat[rest_idx]
if all_virs_concat is not None and all_virs_concat.size > 0:
temp_data["virials"] = all_virs_concat[temp_idx]
all_virs_concat = all_virs_concat[rest_idx]
data_list.append(temp_data)

That probably be done in a separate PR as a refactor. This PR only aims to fix the bug.

Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,8 @@ def to_system_data(folder, type_map=None, labels=True):
all_real_atom_types_concat = index_map[all_real_atom_types_concat]
all_cells_concat = data["cells"]
all_coords_concat = data["coords"]
all_fparam_concat = data.get("fparam", None)
all_aparam_concat = data.get("aparam", None)
if labels:
all_eners_concat = data.get("energies")
all_forces_concat = data.get("forces")
Expand Down Expand Up @@ -56,6 +58,12 @@ def to_system_data(folder, type_map=None, labels=True):
all_cells_concat = all_cells_concat[rest_idx]
temp_data["coords"] = all_coords_concat[temp_idx]
all_coords_concat = all_coords_concat[rest_idx]
if all_fparam_concat is not None:
temp_data["fparam"] = all_fparam_concat[temp_idx]
all_fparam_concat = all_fparam_concat[rest_idx]
if all_aparam_concat is not None:
temp_data["aparam"] = all_aparam_concat[temp_idx]
all_aparam_concat = all_aparam_concat[rest_idx]
if labels:
if all_eners_concat is not None and all_eners_concat.size > 0:
temp_data["energies"] = all_eners_concat[temp_idx]
Expand Down
144 changes: 142 additions & 2 deletions tests/test_deepmd_mixed.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,11 @@
)
from context import dpdata

from dpdata.data_type import (
Axis,
DataType,
)


class TestMixedMultiSystemsDumpLoad(
unittest.TestCase, CompLabeledMultiSys, MultiSystems, MSAllIsNoPBC
Expand Down Expand Up @@ -455,5 +460,140 @@ def tearDown(self):
shutil.rmtree("tmp.deepmd.mixed.single")


if __name__ == "__main__":
unittest.main()
class TestMixedSystemWithFparamAparam(
unittest.TestCase, CompLabeledMultiSys, MultiSystems, MSAllIsNoPBC
):
def setUp(self):
self.places = 6
self.e_places = 6
self.f_places = 6
self.v_places = 6

new_datatypes = [
DataType(
"fparam",
np.ndarray,
shape=(Axis.NFRAMES, 2),
required=False,
),
DataType(
"aparam",
np.ndarray,
shape=(Axis.NFRAMES, Axis.NATOMS, 3),
required=False,
),
]

for datatype in new_datatypes:
dpdata.System.register_data_type(datatype)
dpdata.LabeledSystem.register_data_type(datatype)

# C1H4
system_1 = dpdata.LabeledSystem(
"gaussian/methane.gaussianlog", fmt="gaussian/log"
)

# C1H3
system_2 = dpdata.LabeledSystem(
"gaussian/methane_sub.gaussianlog", fmt="gaussian/log"
)

tmp_data_1 = system_1.data.copy()
nframes_1 = tmp_data_1["coords"].shape[0]
natoms_1 = tmp_data_1["atom_types"].shape[0]
tmp_data_1["fparam"] = np.random.random([nframes_1, 2])
tmp_data_1["aparam"] = np.random.random([nframes_1, natoms_1, 3])
system_1_with_params = dpdata.LabeledSystem(data=tmp_data_1)

tmp_data_2 = system_2.data.copy()
nframes_2 = tmp_data_2["coords"].shape[0]
natoms_2 = tmp_data_2["atom_types"].shape[0]
tmp_data_2["fparam"] = np.random.random([nframes_2, 2])
tmp_data_2["aparam"] = np.random.random([nframes_2, natoms_2, 3])
system_2_with_params = dpdata.LabeledSystem(data=tmp_data_2)

tmp_data_3 = system_1.data.copy()
nframes_3 = tmp_data_3["coords"].shape[0]
tmp_data_3["atom_numbs"] = [1, 1, 1, 2]
tmp_data_3["atom_names"] = ["C", "H", "A", "B"]
tmp_data_3["atom_types"] = np.array([0, 1, 2, 3, 3])
natoms_3 = len(tmp_data_3["atom_types"])
tmp_data_3["fparam"] = np.random.random([nframes_3, 2])
tmp_data_3["aparam"] = np.random.random([nframes_3, natoms_3, 3])
# C1H1A1B2 with params
system_3_with_params = dpdata.LabeledSystem(data=tmp_data_3)

self.ms = dpdata.MultiSystems(
system_1_with_params, system_2_with_params, system_3_with_params
)

self.ms.to_deepmd_npy_mixed("tmp.deepmd.fparam.aparam")
self.place_holder_ms = dpdata.MultiSystems()
self.place_holder_ms.from_deepmd_npy(
"tmp.deepmd.fparam.aparam", fmt="deepmd/npy"
)
self.systems = dpdata.MultiSystems()
self.systems.from_deepmd_npy_mixed(
"tmp.deepmd.fparam.aparam", fmt="deepmd/npy/mixed"
)

self.ms_1 = self.ms
self.ms_2 = self.systems

mixed_sets = glob("tmp.deepmd.fparam.aparam/*/set.*")
for i in mixed_sets:
self.assertEqual(
os.path.exists(os.path.join(i, "real_atom_types.npy")), True
)

self.system_names = ["C1H4A0B0", "C1H3A0B0", "C1H1A1B2"]
self.system_sizes = {"C1H4A0B0": 1, "C1H3A0B0": 1, "C1H1A1B2": 1}
self.atom_names = ["C", "H", "A", "B"]

def tearDown(self):
if os.path.exists("tmp.deepmd.fparam.aparam"):
shutil.rmtree("tmp.deepmd.fparam.aparam")

def test_len(self):
self.assertEqual(len(self.ms), 3)
self.assertEqual(len(self.systems), 3)

def test_get_nframes(self):
self.assertEqual(self.ms.get_nframes(), 3)
self.assertEqual(self.systems.get_nframes(), 3)

def test_str(self):
self.assertEqual(str(self.ms), "MultiSystems (3 systems containing 3 frames)")
self.assertEqual(
str(self.systems), "MultiSystems (3 systems containing 3 frames)"
)

def test_fparam_exists(self):
for formula in self.system_names:
if formula in self.ms.systems:
self.assertTrue("fparam" in self.ms[formula].data)
if formula in self.systems.systems:
self.assertTrue("fparam" in self.systems[formula].data)

for formula in self.system_names:
if formula in self.ms.systems and formula in self.systems.systems:
np.testing.assert_almost_equal(
self.ms[formula].data["fparam"],
self.systems[formula].data["fparam"],
decimal=self.places,
)

def test_aparam_exists(self):
for formula in self.system_names:
if formula in self.ms.systems:
self.assertTrue("aparam" in self.ms[formula].data)
if formula in self.systems.systems:
self.assertTrue("aparam" in self.systems[formula].data)

for formula in self.system_names:
if formula in self.ms.systems and formula in self.systems.systems:
np.testing.assert_almost_equal(
self.ms[formula].data["aparam"],
self.systems[formula].data["aparam"],
decimal=self.places,
)