_target_: gotennet.datamodules.datamodule.DataModule

hparams:
  # --- dataset selection ---
  dataset: rMD17
  dataset_arg:                       # molecule name(s); typically set per-experiment or on the CLI
  dataset_root: ${paths.data_dir}/rmd17 # data_path is specified in config.yaml
  # dataset_arg: energy_U0
  derivative: true                   # train on forces (dE/dx) as well as energies

  # --- split configuration ---
  split_mode: null
  reload: 0
  splits: null                       # null = random split; 0-4 selects an official rMD17 split file
  train_size: 950
  val_size: 50
  test_size: null
  seed: 1

  # --- loader / training setup ---
  batch_size: 4
  inference_batch_size: 16
  standardize: true
  num_workers: 12
  output_dir: ${paths.output_dir}
  ngpus: 1
  num_nodes: 1
  precision: 32
  task: train
  distributed_backend: ddp
  redirect: false
  accelerator: gpu
  test_interval: 1500
  save_interval: 1
  prior_model: null
  normalize_positions: false
# @package _global_

# to execute this experiment run:
# python train.py experiment=example

defaults:
  - override /datamodule: rmd17.yaml
  - override /model: gotennet.yaml
  - override /callbacks: default.yaml
  - override /logger: wandb.yaml # set logger here or use command line (e.g. `python train.py logger=tensorboard`)
  - override /trainer: default.yaml

datamodule:
  hparams:
    seed: 1
    standardize: true
    splits: 0                  # use official rMD17 split file #0
    inference_batch_size: 4

label: "aspirin"

model:
  lr: 0.0002
  lr_warmup_steps: 1000
  lr_monitor: "validation/val_loss"
  lr_minlr: 1.e-07
  lr_patience: 30
  weight_decay: 0.0
  use_ema: True
  task_config:
    task_loss: "MSELoss"
    ema_rates:                 # EMA smoothing rates for [energy, force] losses
      - 0.05
      - 1.00
    loss_weights:              # relative weight of [energy, force] in total loss
      - 0.05
      - 0.95

  representation:
    n_interactions: 16
    n_atom_basis: 192
    radial_basis: "expnorm"
    n_rbf: 32
    edge_updates: "linw_mlpa"
    lmax: 2
    evec_dim: 768
    emlp_dim: 768

  output:
    n_hidden: 256

trainer:
  max_epochs: 3000
  inference_mode: False        # gradients needed at eval time for force prediction

callbacks:
  early_stopping:
    monitor: "validation/val_loss_og"
    patience: 1000
  model_checkpoint:
    monitor: "validation/val_loss_og"

project: "gotennet_rmd17"

task: "rMD17"
import os
import os.path as osp

import numpy as np
import torch
from pytorch_lightning.utilities import rank_zero_warn
from torch_geometric.data import Data, InMemoryDataset, download_url, extract_tar
from tqdm import tqdm


class rMD17(InMemoryDataset):
    """Revised MD17 (rMD17) dataset of small-molecule MD trajectories.

    Downloads the rMD17 archive from figshare, extracts the per-molecule
    ``.npz`` files and converts each configuration into a PyG ``Data``
    object with atomic numbers ``z``, positions ``pos``, total energy ``y``
    and per-atom forces ``dy``.
    """

    # Figshare download of the full rMD17 archive (a tar.bz2 of npz files).
    revised_url = ('https://figshare.com/ndownloader/files/23950376')

    # Mapping from molecule name to its npz file inside the archive.
    molecule_files = dict(
        aspirin='rmd17_aspirin.npz',
        azobenzene='rmd17_azobenzene.npz',
        benzene='rmd17_benzene.npz',
        ethanol='rmd17_ethanol.npz',
        malonaldehyde='rmd17_malonaldehyde.npz',
        naphthalene='rmd17_naphthalene.npz',
        paracetamol='rmd17_paracetamol.npz',
        salicylic='rmd17_salicylic.npz',
        toluene='rmd17_toluene.npz',
        uracil='rmd17_uracil.npz',
    )

    available_molecules = list(molecule_files.keys())

    def __init__(self, root, transform=None, pre_transform=None, dataset_arg=None):
        """Initialize the dataset.

        Args:
            root: Root directory for raw and processed files.
            transform: Per-access transform applied to each ``Data`` object.
            pre_transform: Transform applied once during processing.
            dataset_arg: Comma-separated molecule name(s), or ``"all"``.
        """
        assert dataset_arg is not None, (
            "Please provide the desired comma separated molecule(s) through"
            f"'dataset_arg'. Available molecules are {', '.join(rMD17.available_molecules)} "
            "or 'all' to train on the combined dataset."
        )

        if dataset_arg == "all":
            dataset_arg = ",".join(rMD17.available_molecules)
        self.molecules = dataset_arg.split(",")

        # Mixing molecules is allowed but energies are not re-referenced.
        if len(self.molecules) > 1:
            rank_zero_warn(
                "MD17 molecules have different reference energies, "
                "which is not accounted for during training."
            )

        super(rMD17, self).__init__(root, transform, pre_transform)

        # weights_only=False: the processed file contains pickled PyG objects
        # that this class wrote itself during process().
        self.data, self.slices = torch.load(self.processed_paths[0], weights_only=False)

    def len(self):
        # One sample per energy entry; y is collated to shape (N, 1).
        return self.data.y.size(0)

    @property
    def raw_file_names(self):
        # Paths are relative to raw_dir, matching the extracted archive layout.
        return [osp.join('rmd17', 'npz_data', rMD17.molecule_files[mol]) for mol in self.molecules]

    def get_split(self, idx):
        """Load one of the five official rMD17 train/test index splits.

        Args:
            idx: Split id in [0, 4]; maps to files ``index_*_01..05.csv``.

        Returns:
            list: ``[train_indices, test_indices]`` as lists of ints.
        """
        assert idx in [0, 1, 2, 3, 4]

        sets = ['index_train', 'index_test']

        out = []
        for set_name in sets:
            # Split files are 1-indexed on disk (e.g. index_train_01.csv).
            split_path = osp.join(self.root, 'raw', 'rmd17', 'splits', set_name+f'_0{idx+1}.csv')
            # check file exists
            if not osp.exists(split_path):
                raise FileNotFoundError(f"File {split_path} not found")
            # load csv
            with open(split_path, 'r') as f:
                split = [int(line.strip()) for line in f.readlines()]
            out.append(split)
        return out

    @property
    def processed_file_names(self):
        return [f"rmd17-{mol}.pt" for mol in self.molecules]

    def download(self):
        path = download_url(self.revised_url, self.raw_dir)
        extract_tar(path, self.raw_dir, mode='r:bz2')
        os.unlink(path)

    def process(self):
        """Convert each raw npz file into a collated processed .pt file."""
        for path, processed_path in zip(self.raw_paths, self.processed_paths, strict=False):
            data_npz = np.load(path)
            z = torch.from_numpy(data_npz["nuclear_charges"]).long()
            positions = torch.from_numpy(data_npz["coords"]).float()
            energies = torch.from_numpy(data_npz["energies"]).float()
            forces = torch.from_numpy(data_npz["forces"]).float()
            # In-place: energies becomes (N, 1) so each y below is (1,).
            energies.unsqueeze_(1)

            samples = []
            for pos, y, dy in tqdm(zip(positions, energies, forces, strict=False), total=energies.size(0)):
                # NOTE(review): y.unsqueeze(1) on the already-unsqueezed energy
                # yields a (1, 1) tensor per sample (matches the upstream
                # TorchMD-Net MD17 convention) — confirm downstream consumers.
                data = Data(z=z, pos=pos, y=y.unsqueeze(1), dy=dy)

                if self.pre_filter is not None:
                    data = self.pre_filter(data)

                if self.pre_transform is not None:
                    data = self.pre_transform(data)

                samples.append(data)

            data, slices = self.collate(samples)
            torch.save((data, slices), processed_path)
tqdm(zip(positions, energies, forces, strict=False), total=energies.size(0)): + + data = Data(z=z, pos=pos, y=y.unsqueeze(1), dy=dy) + + if self.pre_filter is not None: + data = self.pre_filter(data) + + if self.pre_transform is not None: + data = self.pre_transform(data) + + samples.append(data) + + data, slices = self.collate(samples) + torch.save((data, slices), processed_path) diff --git a/gotennet/datamodules/datamodule.py b/gotennet/datamodules/datamodule.py index 6459fff..9ed250b 100644 --- a/gotennet/datamodules/datamodule.py +++ b/gotennet/datamodules/datamodule.py @@ -11,6 +11,7 @@ from gotennet import utils from .components.qm9 import QM9 +from .components.rmd17 import rMD17 from .components.utils import MissingLabelException, make_splits log = utils.get_logger(__name__) @@ -300,3 +301,43 @@ def _prepare_QM9(self): ) return idx_train, idx_val, idx_test + + def _prepare_rMD17(self): + self.dataset = rMD17(root=self.hparams["dataset_root"], dataset_arg=self.hparams["dataset_arg"]) + + train_size = self.hparams["train_size"] + val_size = self.hparams["val_size"] + + splits = self.hparams.get("splits", None) + if splits is not None: + split_idxs = self.dataset.get_split(splits) + assert len(split_idxs) == 2, "Expected two splits" + assert len(split_idxs[ + 0]) == train_size + val_size, f"Expected train+val{train_size + val_size} size != {len(split_idxs[0])} size" + # idx_train, idx_val = split_idxs[0][:train_size], split_idxs[0][train_size:] + idx_train_local, idx_val_local, _ = make_splits( + len(split_idxs[0]), + train_size, + val_size, + None, + self.hparams["seed"], + join(self.hparams["output_dir"], "splits.npz"), + splits=None, + ) + train_val = torch.tensor(split_idxs[0]) + idx_train = train_val[idx_train_local] + idx_val = train_val[idx_val_local] + idx_test = split_idxs[1] + print(f"[ID: {splits}] train {len(idx_train)}, val {len(idx_val)}, test {len(idx_test)}") + else: + idx_train, idx_val, idx_test = make_splits( + len(self.dataset), + 
"""MD17 task implementation for molecular dynamics simulations."""

from __future__ import absolute_import, division, print_function

import torch
import torch.nn.functional as F
import torchmetrics
from torch.nn import MSELoss

from gotennet.models.components.outputs import AtomwiseV3
from gotennet.models.tasks.Task import Task


class MDTask(Task):
    """
    Task for MD17 molecular dynamics dataset.

    This task predicts energy and forces for molecular dynamics simulations.
    """

    name = "MD17"

    def __init__(
        self,
        representation,
        label_key,
        dataset_meta,
        task_config=None,
        **kwargs
    ):
        """
        Initialize the MD17 task.

        Args:
            representation: The representation model to use.
            label_key: The key for the label in the dataset.
            dataset_meta: Metadata about the dataset (provides 'mean'/'std').
            task_config (dict, optional): Configuration for the task. Defaults to None.
            **kwargs: Additional keyword arguments forwarded to the Task base.
        """
        # Default weights: energy 0.05, force 0.95 (forces dominate the loss).
        task_defaults = {
            "loss_weights": [0.05, 0.95],
        }

        super().__init__(
            representation,
            label_key,
            dataset_meta,
            task_config=task_config,
            task_defaults=task_defaults,
            **kwargs
        )
        # Single regression target (the total energy) per molecule.
        self.num_classes = 1
+ """ + task_defaults = { + "loss_weights": [0.05, 0.95], + } + + super().__init__( + representation, + label_key, + dataset_meta, + task_config=task_config, + task_defaults=task_defaults, + **kwargs + ) + self.num_classes = 1 + + def get_losses(self): + """ + Get the loss functions for the MD17 task. + + Returns: + list: A list of dictionaries containing loss function configurations for energy and force. + """ + ema_rates = self.config.get('ema_rates', [0.05, 1.00]) + print(ema_rates, 'ema_rates', self.config['loss_weights'], 'loss_weights') + if len(ema_rates) == 1: + ema_rates.append(1.00) + + return [ + { + "metric": MSELoss, + "prediction": 'energy', + "target": 'y', + 'ema_rate': ema_rates[0], + "loss_weight": self.config['loss_weights'][0] + }, + { + "metric": MSELoss, + "prediction": 'force', + "target": 'dy', + 'ema_rate': ema_rates[1], + "loss_weight": self.config['loss_weights'][1] + } + ] + + def get_metrics(self): + """ + Get the metrics for the MD17 task. + + Returns: + list: A list of dictionaries containing metric configurations for energy and force. + """ + return [ + { + "metric": torchmetrics.MeanSquaredError, + "prediction": 'energy', + "target": 'y', + }, + { + "metric": torchmetrics.MeanAbsoluteError, + "prediction": 'energy', + "target": 'y', + }, + { + "metric": torchmetrics.MeanSquaredError, + "prediction": 'force', + "target": 'dy', + }, + { + "metric": torchmetrics.MeanAbsoluteError, + "prediction": 'force', + "target": 'dy', + } + ] + + def get_output(self, output_config=None): + """ + Get the output module for the MD17 task. + + Args: + output_config: Configuration for the output module. + + Returns: + torch.nn.ModuleList: A list containing the output module. 
+ """ + outputs = AtomwiseV3( + n_in=self.representation.hidden_dim, + mean=self.dataset_meta['mean'], + stddev=self.dataset_meta['std'], + atomref=None, + aggregation_mode="sum", + property="energy", + activation=F.silu, + derivative="force", + negative_dr=True, + create_graph=True, + **output_config + ) + outputs = [outputs] + return torch.nn.ModuleList(outputs) + + def get_evaluator(self): + """ + Get the evaluator for the MD17 task. + + Returns: + None: No special evaluator is needed for this task. + """ + return None + + def get_dataloader_map(self): + """ + Get the dataloader map for the MD17 task. + + Returns: + list: A list containing 'test' as the only dataloader to use. + """ + return ['test'] diff --git a/gotennet/models/tasks/__init__.py b/gotennet/models/tasks/__init__.py index 2191960..49dda24 100644 --- a/gotennet/models/tasks/__init__.py +++ b/gotennet/models/tasks/__init__.py @@ -3,8 +3,10 @@ from __future__ import absolute_import, division, print_function from gotennet.models.tasks.QM9Task import QM9Task +from gotennet.models.tasks.MDTask import MDTask # Dictionary mapping task names to their implementations TASK_DICT = { 'QM9': QM9Task, # QM9 quantum chemistry dataset + 'rMD17': MDTask, # Revised MD17 dataset } diff --git a/gotennet/utils/utils.py b/gotennet/utils/utils.py index 1d071c4..a95b91c 100644 --- a/gotennet/utils/utils.py +++ b/gotennet/utils/utils.py @@ -32,15 +32,14 @@ def find_config_directory() -> str: # Define search paths in order of preference search_paths = [ - os.path.join(current_dir, "configs"), # Check for configs in CWD os.path.join( current_dir, "gotennet", "configs" ), # Check for gotennet/configs in CWD (e.g. running from project root) os.path.abspath( os.path.join(package_location, "..", "configs") ), # Check for ../configs relative to utils.py (i.e. 
gotennet/configs) + os.path.join(current_dir, "configs"), # Check for configs in CWD ] - # Search for configs directory for path in search_paths: if os.path.exists(path) and os.path.isdir(path):