33 changes: 33 additions & 0 deletions gotennet/configs/datamodule/rmd17.yaml
@@ -0,0 +1,33 @@
_target_: gotennet.datamodules.datamodule.DataModule

hparams:
  dataset: rMD17
  dataset_arg:
  dataset_root: ${paths.data_dir}/rmd17 # data_path is specified in config.yaml
  # dataset_arg: energy_U0
  derivative: true
  split_mode: null
  reload: 0
  batch_size: 4
  inference_batch_size: 16
  standardize: true
  splits: null
  train_size: 950
  val_size: 50
  test_size: null
  num_workers: 12
  seed: 1
  output_dir: ${paths.output_dir}
  ngpus: 1
  num_nodes: 1
  precision: 32
  task: train
  distributed_backend: ddp
  redirect: false
  accelerator: gpu
  test_interval: 1500
  save_interval: 1
  prior_model: null
  normalize_positions: false


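As a quick sanity check, the datamodule config above can be instantiated outside of a full training run. A minimal sketch, assuming the `hparams` mapping is passed straight to `DataModule` as the `_target_` line implies, with the `${paths.*}` interpolations and `dataset_arg` filled in by hand (all paths here are illustrative):

# Illustrative only: instantiate the rMD17 datamodule directly from this YAML.
from hydra.utils import instantiate
from omegaconf import OmegaConf

cfg = OmegaConf.load("gotennet/configs/datamodule/rmd17.yaml")
cfg.hparams.dataset_arg = "aspirin"        # choose one rMD17 molecule
cfg.hparams.dataset_root = "./data/rmd17"  # stand-in for ${paths.data_dir}/rmd17
cfg.hparams.output_dir = "./outputs"       # stand-in for ${paths.output_dir}
dm = instantiate(cfg)                      # -> gotennet.datamodules.datamodule.DataModule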
64 changes: 64 additions & 0 deletions gotennet/configs/experiment/rmd17_aspirin.yaml
@@ -0,0 +1,64 @@
# @package _global_

# to execute this experiment run:
# python train.py experiment=rmd17_aspirin

defaults:
  - override /datamodule: rmd17.yaml
  - override /model: gotennet.yaml
  - override /callbacks: default.yaml
  - override /logger: wandb.yaml # set logger here or use command line (e.g. `python train.py logger=tensorboard`)
  - override /trainer: default.yaml

datamodule:
  hparams:
    seed: 1
    standardize: true
    splits: 0
    inference_batch_size: 4

label: "aspirin"

model:
  lr: 0.0002
  lr_warmup_steps: 1000
  lr_monitor: "validation/val_loss"
  lr_minlr: 1.e-07
  lr_patience: 30
  weight_decay: 0.0
  use_ema: True
  task_config:
    task_loss: "MSELoss"
    ema_rates:
      - 0.05
      - 1.00
    loss_weights:
      - 0.05
      - 0.95

  representation:
    n_interactions: 16
    n_atom_basis: 192
    radial_basis: "expnorm"
    n_rbf: 32
    edge_updates: "linw_mlpa"
    lmax: 2
    evec_dim: 768
    emlp_dim: 768

  output:
    n_hidden: 256

trainer:
  max_epochs: 3000
  inference_mode: False

callbacks:
  early_stopping:
    monitor: "validation/val_loss_og"
    patience: 1000
  model_checkpoint:
    monitor: "validation/val_loss_og"

project: "gotennet_rmd17"

task: "rMD17"
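To verify that the overrides above compose as intended before launching training, Hydra's compose API can be used. This is a sketch only; the top-level config name ("train") and the config path are assumptions about the repository layout:

# Illustrative only: resolve the experiment config without starting a run.
from hydra import compose, initialize

with initialize(config_path="gotennet/configs", version_base=None):
    cfg = compose(config_name="train", overrides=["experiment=rmd17_aspirin"])

print(cfg.datamodule.hparams.splits)         # 0 -> use the first published index split
print(cfg.model.lr, cfg.trainer.max_epochs)  # 0.0002 3000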
106 changes: 106 additions & 0 deletions gotennet/datamodules/components/rmd17.py
@@ -0,0 +1,106 @@
import os
import os.path as osp

import numpy as np
import torch
from pytorch_lightning.utilities import rank_zero_warn
from torch_geometric.data import Data, InMemoryDataset, download_url, extract_tar
from tqdm import tqdm


class rMD17(InMemoryDataset):
    revised_url = 'https://figshare.com/ndownloader/files/23950376'

    molecule_files = dict(
        aspirin='rmd17_aspirin.npz',
        azobenzene='rmd17_azobenzene.npz',
        benzene='rmd17_benzene.npz',
        ethanol='rmd17_ethanol.npz',
        malonaldehyde='rmd17_malonaldehyde.npz',
        naphthalene='rmd17_naphthalene.npz',
        paracetamol='rmd17_paracetamol.npz',
        salicylic='rmd17_salicylic.npz',
        toluene='rmd17_toluene.npz',
        uracil='rmd17_uracil.npz',
    )

    available_molecules = list(molecule_files.keys())

    def __init__(self, root, transform=None, pre_transform=None, dataset_arg=None):
        assert dataset_arg is not None, (
            "Please provide the desired comma separated molecule(s) through "
            f"'dataset_arg'. Available molecules are {', '.join(rMD17.available_molecules)} "
            "or 'all' to train on the combined dataset."
        )

        if dataset_arg == "all":
            dataset_arg = ",".join(rMD17.available_molecules)
        self.molecules = dataset_arg.split(",")

        if len(self.molecules) > 1:
            rank_zero_warn(
                "MD17 molecules have different reference energies, "
                "which is not accounted for during training."
            )

        super(rMD17, self).__init__(root, transform, pre_transform)

        self.data, self.slices = torch.load(self.processed_paths[0], weights_only=False)

    def len(self):
        return self.data.y.size(0)

    @property
    def raw_file_names(self):
        return [osp.join('rmd17', 'npz_data', rMD17.molecule_files[mol]) for mol in self.molecules]

    def get_split(self, idx):
        assert idx in [0, 1, 2, 3, 4]

        sets = ['index_train', 'index_test']

        out = []
        for set_name in sets:
            split_path = osp.join(self.root, 'raw', 'rmd17', 'splits', set_name + f'_0{idx + 1}.csv')
            # check file exists
            if not osp.exists(split_path):
                raise FileNotFoundError(f"File {split_path} not found")
            # load csv
            with open(split_path, 'r') as f:
                split = [int(line.strip()) for line in f.readlines()]
            out.append(split)
        return out

    @property
    def processed_file_names(self):
        return [f"rmd17-{mol}.pt" for mol in self.molecules]

    def download(self):
        path = download_url(self.revised_url, self.raw_dir)
        extract_tar(path, self.raw_dir, mode='r:bz2')
        os.unlink(path)

    def process(self):
        for path, processed_path in zip(self.raw_paths, self.processed_paths, strict=False):
            data_npz = np.load(path)
            z = torch.from_numpy(data_npz["nuclear_charges"]).long()
            positions = torch.from_numpy(data_npz["coords"]).float()
            energies = torch.from_numpy(data_npz["energies"]).float()
            forces = torch.from_numpy(data_npz["forces"]).float()
            energies.unsqueeze_(1)

            samples = []
            for pos, y, dy in tqdm(zip(positions, energies, forces, strict=False), total=energies.size(0)):
                data = Data(z=z, pos=pos, y=y.unsqueeze(1), dy=dy)

                if self.pre_filter is not None:
                    data = self.pre_filter(data)

                if self.pre_transform is not None:
                    data = self.pre_transform(data)

                samples.append(data)

            data, slices = self.collate(samples)
            torch.save((data, slices), processed_path)
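A short usage sketch of the dataset class above (the root path is illustrative): on first access it downloads the revised MD17 tarball from figshare, extracts it under `raw/`, and caches a processed `rmd17-<molecule>.pt` file.

# Illustrative only: load one molecule and inspect a single conformation.
from gotennet.datamodules.components.rmd17 import rMD17

dataset = rMD17(root="./data/rmd17", dataset_arg="aspirin")
sample = dataset[0]
print(len(dataset))      # number of conformations
print(sample.z.shape)    # atomic numbers, (n_atoms,)
print(sample.pos.shape)  # Cartesian coordinates, (n_atoms, 3)
print(sample.y.shape)    # total energy
print(sample.dy.shape)   # forces, (n_atoms, 3)

train_idx, test_idx = dataset.get_split(0)  # first of the five published index splits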
41 changes: 41 additions & 0 deletions gotennet/datamodules/datamodule.py
@@ -11,6 +11,7 @@
from gotennet import utils

from .components.qm9 import QM9
from .components.rmd17 import rMD17
from .components.utils import MissingLabelException, make_splits

log = utils.get_logger(__name__)
@@ -300,3 +301,43 @@ def _prepare_QM9(self):
        )

        return idx_train, idx_val, idx_test

    def _prepare_rMD17(self):
        self.dataset = rMD17(root=self.hparams["dataset_root"], dataset_arg=self.hparams["dataset_arg"])

        train_size = self.hparams["train_size"]
        val_size = self.hparams["val_size"]

        splits = self.hparams.get("splits", None)
        if splits is not None:
            split_idxs = self.dataset.get_split(splits)
            assert len(split_idxs) == 2, "Expected two splits"
            assert len(split_idxs[0]) == train_size + val_size, (
                f"Expected train+val size {train_size + val_size}, got {len(split_idxs[0])}"
            )
            # idx_train, idx_val = split_idxs[0][:train_size], split_idxs[0][train_size:]
            idx_train_local, idx_val_local, _ = make_splits(
                len(split_idxs[0]),
                train_size,
                val_size,
                None,
                self.hparams["seed"],
                join(self.hparams["output_dir"], "splits.npz"),
                splits=None,
            )
            train_val = torch.tensor(split_idxs[0])
            idx_train = train_val[idx_train_local]
            idx_val = train_val[idx_val_local]
            idx_test = split_idxs[1]
            print(f"[ID: {splits}] train {len(idx_train)}, val {len(idx_val)}, test {len(idx_test)}")
        else:
            idx_train, idx_val, idx_test = make_splits(
                len(self.dataset),
                train_size,
                val_size,
                None,
                self.hparams["seed"],
                join(self.hparams["output_dir"], "splits.npz"),
                self.hparams["splits"],
            )

        return idx_train, idx_val, idx_test
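When an index split is selected, the published train CSV is treated as a train+val pool and re-split locally; the local indices returned by `make_splits` are then mapped back to dataset-level indices. A toy illustration of that re-indexing step, with a plain permutation standing in for `make_splits`:

# Illustrative only: map local (pool-relative) indices back to dataset indices.
import torch

pool = torch.tensor([3, 17, 42, 56, 81, 90])  # dataset indices from index_train_01.csv
train_size, val_size = 4, 2
perm = torch.randperm(len(pool))              # stand-in for make_splits(...)
idx_train = pool[perm[:train_size]]           # dataset-level training indices
idx_val = pool[perm[train_size:train_size + val_size]]
print(idx_train.tolist(), idx_val.tolist())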
2 changes: 1 addition & 1 deletion gotennet/models/components/layers.py
@@ -878,7 +878,7 @@ def forward(self, pos, batch):

        if self.loop:
            mask = edge_index[0] != edge_index[1]
-           edge_weight = torch.zeros(edge_vec.size(0), device=edge_vec.device)
+           edge_weight = torch.zeros(edge_vec.size(0), device=edge_vec.device, dtype=edge_vec.dtype)
            edge_weight[mask] = torch.norm(edge_vec[mask], dim=-1)
        else:
            edge_weight = torch.norm(edge_vec, dim=-1)
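The one-line change above makes the masked branch allocate `edge_weight` in the same dtype as `edge_vec`; previously it used the `torch.zeros` default dtype, so with e.g. double-precision positions the two branches returned different dtypes. A small illustration:

# Illustrative only: torch.zeros defaults to the global default dtype (float32).
import torch

edge_vec = torch.randn(5, 3, dtype=torch.float64)
old = torch.zeros(edge_vec.size(0), device=edge_vec.device)                        # float32
new = torch.zeros(edge_vec.size(0), device=edge_vec.device, dtype=edge_vec.dtype)  # float64, matches edge_vec
print(old.dtype, new.dtype)  # torch.float32 torch.float64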