Skip to content

Commit 20c5de6

Browse files
Merge pull request #192 from guillemsimeon/main
Inclusion of MD22 dataset
2 parents 3bfaf26 + f06f32d commit 20c5de6

File tree

2 files changed

+95
-0
lines changed

2 files changed

+95
-0
lines changed

torchmdnet/datasets/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
from .custom import Custom
55
from .hdf import HDF5
66
from .md17 import MD17
7+
from .md22 import MD22
78
from .qm9 import QM9
89
from .qm9q import QM9q
910
from .spice import SPICE
@@ -21,6 +22,7 @@
2122
"GDB10to13",
2223
"HDF5",
2324
"MD17",
25+
"MD22",
2426
"QM9",
2527
"QM9q",
2628
"S66X8",

torchmdnet/datasets/md22.py

Lines changed: 93 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,93 @@
1+
import os
2+
import os.path as osp
3+
from typing import Callable, List, Optional
4+
import numpy as np
5+
import torch
6+
from torch_geometric.data import (
7+
Data,
8+
InMemoryDataset,
9+
download_url,
10+
extract_tar,
11+
extract_zip,
12+
)
13+
14+
class MD22(InMemoryDataset):
15+
16+
gdml_url = 'http://quantum-machine.org/gdml/data/npz'
17+
18+
file_names = {
19+
'AT-AT-CG-CG': 'md22_AT-AT-CG-CG.npz',
20+
'AT-AT': 'md22_AT-AT.npz',
21+
'Ac-Ala3-NHMe': 'md22_Ac-Ala3-NHMe.npz',
22+
'DHA': 'md22_DHA.npz',
23+
'buckyball-catcher': 'md22_buckyball-catcher.npz',
24+
'dw-nanotube': 'md22_dw_nanotube.npz',
25+
'stachyose': 'md22_stachyose.npz',
26+
}
27+
28+
def __init__(
29+
self,
30+
root: str,
31+
molecules: str,
32+
transform: Optional[Callable] = None,
33+
pre_transform: Optional[Callable] = None,
34+
pre_filter: Optional[Callable] = None,
35+
):
36+
name = molecules
37+
if name not in self.file_names:
38+
raise ValueError(f"Unknown dataset name '{name}'")
39+
40+
self.name = name
41+
42+
super().__init__(root, transform, pre_transform, pre_filter)
43+
44+
idx = 0
45+
self.data, self.slices = torch.load(self.processed_paths[idx])
46+
47+
def mean(self) -> float:
48+
return float(self._data.energy.mean())
49+
50+
@property
51+
def raw_dir(self) -> str:
52+
return osp.join(self.root, self.name, 'raw')
53+
54+
@property
55+
def processed_dir(self) -> str:
56+
return osp.join(self.root, 'processed', self.name)
57+
58+
@property
59+
def raw_file_names(self) -> str:
60+
name = self.file_names[self.name]
61+
return name
62+
63+
@property
64+
def processed_file_names(self) -> List[str]:
65+
return ['data.pt']
66+
67+
def download(self):
68+
url = f'{self.gdml_url}/{self.file_names[self.name]}'
69+
path = download_url(url, self.raw_dir)
70+
71+
def process(self):
72+
it = zip(self.raw_paths, self.processed_paths)
73+
for raw_path, processed_path in it:
74+
raw_data = np.load(raw_path)
75+
76+
z = torch.from_numpy(raw_data['z']).long()
77+
pos = torch.from_numpy(raw_data['R']).float()
78+
energy = torch.from_numpy(raw_data['E']).float()
79+
force = torch.from_numpy(raw_data['F']).float()
80+
81+
data_list = []
82+
for i in range(pos.size(0)):
83+
data = Data(z=z, pos=pos[i], y=energy[i].unsqueeze(-1), neg_dy=force[i])
84+
if self.pre_filter is not None and not self.pre_filter(data):
85+
continue
86+
if self.pre_transform is not None:
87+
data = self.pre_transform(data)
88+
data_list.append(data)
89+
90+
torch.save(self.collate(data_list), processed_path)
91+
92+
def __repr__(self) -> str:
93+
return f"{self.__class__.__name__}({len(self)}, name='{self.name}')"

0 commit comments

Comments
 (0)