diff --git a/benchmarks/matbench_v0.1_eComFormer/config.py b/benchmarks/matbench_v0.1_eComFormer/config.py new file mode 100644 index 00000000..0b191a30 --- /dev/null +++ b/benchmarks/matbench_v0.1_eComFormer/config.py @@ -0,0 +1,195 @@ +"""Pydantic model for default configuration and validation.""" +"""Implementation based on the template of ALIGNN.""" + +import subprocess +from typing import Optional, Union +import os +from pydantic import root_validator + +# vfrom pydantic import Field, root_validator, validator +from pydantic.typing import Literal +from matformer.utils import BaseSettings +from matformer.models.pyg_att import MatformerConfig + +# from typing import List + +try: + VERSION = ( + subprocess.check_output(["git", "rev-parse", "HEAD"]).decode().strip() + ) +except Exception as exp: + VERSION = "NA" + pass + + +FEATURESET_SIZE = {"basic": 11, "atomic_number": 1, "cfid": 438, "cgcnn": 92} + + +TARGET_ENUM = Literal[ + "formation_energy_peratom", + "optb88vdw_bandgap", + "bulk_modulus_kv", + "shear_modulus_gv", + "mbj_bandgap", + "slme", + "magmom_oszicar", + "spillage", + "kpoint_length_unit", + "encut", + "optb88vdw_total_energy", + "epsx", + "epsy", + "epsz", + "mepsx", + "mepsy", + "mepsz", + "max_ir_mode", + "min_ir_mode", + "n-Seebeck", + "p-Seebeck", + "n-powerfact", + "p-powerfact", + "ncond", + "pcond", + "nkappa", + "pkappa", + "ehull", + "exfoliation_energy", + "dfpt_piezo_max_dielectric", + "dfpt_piezo_max_eij", + "dfpt_piezo_max_dij", + "gap pbe", + "e_form", + "e_hull", + "energy_per_atom", + "formation_energy_per_atom", + "band_gap", + "e_above_hull", + "mu_b", + "bulk modulus", + "shear modulus", + "elastic anisotropy", + "U0", + "HOMO", + "LUMO", + "R2", + "ZPVE", + "omega1", + "mu", + "alpha", + "homo", + "lumo", + "gap", + "r2", + "zpve", + "U", + "H", + "G", + "Cv", + "A", + "B", + "C", + "all", + "target", + "max_efg", + "avg_elec_mass", + "avg_hole_mass", + "_oqmd_band_gap", + "_oqmd_delta_e", + "_oqmd_stability", + "edos_up", + "pdos_elast", + "bandgap", + "energy_total", + "net_magmom", + "b3lyp_homo", + "b3lyp_lumo", + "b3lyp_gap", + "b3lyp_scharber_pce", + "b3lyp_scharber_voc", + "b3lyp_scharber_jsc", + "log_kd_ki", + "max_co2_adsp", + "min_co2_adsp", + "lcd", + "pld", + "void_fraction", + "surface_area_m2g", + "surface_area_m2cm3", + "indir_gap", + "f_enp", + "final_energy", + "energy_per_atom", + "label", +] + + +class TrainingConfig(BaseSettings): + """Training config defaults and validation.""" + + version: str = VERSION + + # dataset configuration + dataset: Literal[ + "dft_3d", + "megnet", + "mpf", + ] = "dft_3d" + target: TARGET_ENUM = "formation_energy_peratom" + atom_features: Literal["basic", "atomic_number", "cfid", "cgcnn"] = "cgcnn" + neighbor_strategy: Literal["k-nearest", "voronoi", "pairwise-k-nearest"] = "k-nearest" + id_tag: Literal["jid", "id", "_oqmd_entry_id"] = "jid" + + # logging configuration + + # training configuration + random_seed: Optional[int] = 123 + classification_threshold: Optional[float] = None + n_val: Optional[int] = None + n_test: Optional[int] = None + n_train: Optional[int] = None + train_ratio: Optional[float] = 0.8 + val_ratio: Optional[float] = 0.1 + test_ratio: Optional[float] = 0.1 + target_multiplication_factor: Optional[float] = None + epochs: int = 300 + batch_size: int = 64 + weight_decay: float = 0 + learning_rate: float = 1e-2 + filename: str = "sample" + warmup_steps: int = 2000 + criterion: Literal["mse", "l1", "poisson", "zig"] = "mse" + optimizer: Literal["adamw", "sgd"] = "adamw" + scheduler: 
Literal["onecycle", "none", "step", "polynomial"] = "onecycle" + pin_memory: bool = False + save_dataloader: bool = False + write_checkpoint: bool = True + write_predictions: bool = True + store_outputs: bool = True + progress: bool = True + log_tensorboard: bool = False + standard_scalar_and_pca: bool = False + use_canonize: bool = True + num_workers: int = 2 + cutoff: float = 4.0 + max_neighbors: int = 12 + keep_data_order: bool = False + distributed: bool = False + n_early_stopping: Optional[int] = None # typically 50 + output_dir: str = os.path.abspath(".") # typically 50 + matrix_input: bool = False + pyg_input: bool = False + use_lattice: bool = False + use_angle: bool = False + + # model configuration + model = MatformerConfig(name="matformer") + print(model) + @root_validator() + def set_input_size(cls, values): + """Automatically configure node feature dimensionality.""" + values["model"].atom_input_features = FEATURESET_SIZE[ + values["atom_features"] + ] + + return values diff --git a/benchmarks/matbench_v0.1_eComFormer/data.py b/benchmarks/matbench_v0.1_eComFormer/data.py new file mode 100644 index 00000000..3e386584 --- /dev/null +++ b/benchmarks/matbench_v0.1_eComFormer/data.py @@ -0,0 +1,664 @@ +"""Implementation based on the template of ALIGNN.""" + +import imp +import random +from pathlib import Path +from typing import Optional + +# from typing import Dict, List, Optional, Set, Tuple + +import os +import torch +import numpy as np +import pandas as pd +from jarvis.core.atoms import Atoms +from matformer.graphs import PygGraph, PygStructureDataset +# +from pymatgen.io.jarvis import JarvisAtomsAdaptor +from jarvis.db.figshare import data as jdata +from torch.utils.data import DataLoader +from tqdm import tqdm +import math +from jarvis.db.jsonutils import dumpjson +from pandarallel import pandarallel +pandarallel.initialize(progress_bar=True) +# from sklearn.pipeline import Pipeline +import pickle as pk + +from sklearn.preprocessing import StandardScaler + +# use pandas progress_apply +tqdm.pandas() + + +def load_dataset( + name: str = "dft_3d", + target=None, + limit: Optional[int] = None, + classification_threshold: Optional[float] = None, +): + """Load jarvis data.""" + d = jdata(name) + data = [] + for i in d: + if i[target] != "na" and not math.isnan(i[target]): + if classification_threshold is not None: + if i[target] <= classification_threshold: + i[target] = 0 + elif i[target] > classification_threshold: + i[target] = 1 + else: + raise ValueError( + "Check classification data type.", + i[target], + type(i[target]), + ) + data.append(i) + d = data + if limit is not None: + d = d[:limit] + d = pd.DataFrame(d) + return d + + +def mean_absolute_deviation(data, axis=None): + """Get Mean absolute deviation.""" + return np.mean(np.absolute(data - np.mean(data, axis)), axis) + + + +def load_pyg_graphs( + df: pd.DataFrame, + name: str = "dft_3d", + neighbor_strategy: str = "k-nearest", + cutoff: float = 8, + max_neighbors: int = 12, + cachedir: Optional[Path] = None, + use_canonize: bool = False, + use_lattice: bool = False, + use_angle: bool = False, +): + """Construct crystal graphs. + + Load only atomic number node features + and bond displacement vector edge features. + + Resulting graphs have scheme e.g. 
+ ``` + Graph(num_nodes=12, num_edges=156, + ndata_schemes={'atom_features': Scheme(shape=(1,)} + edata_schemes={'r': Scheme(shape=(3,)}) + ``` + """ + + def atoms_to_graph(atoms): + """Convert structure dict to DGLGraph.""" + adaptor = JarvisAtomsAdaptor() + structure = adaptor.get_atoms(atoms) + return PygGraph.atom_dgl_multigraph( + structure, + neighbor_strategy=neighbor_strategy, + cutoff=cutoff, + atom_features="atomic_number", + max_neighbors=max_neighbors, + compute_line_graph=False, + use_canonize=use_canonize, + use_lattice=use_lattice, + use_angle=use_angle, + ) + + graphs = df["atoms"].parallel_apply(atoms_to_graph).values + # graphs = df["atoms"].apply(atoms_to_graph).values + + return graphs + + +def get_id_train_val_test( + total_size=1000, + split_seed=123, + train_ratio=None, + val_ratio=0.1, + test_ratio=0.1, + n_train=None, + n_test=None, + n_val=None, + keep_data_order=False, +): + """Get train, val, test IDs.""" + if ( + train_ratio is None + and val_ratio is not None + and test_ratio is not None + ): + if train_ratio is None: + assert val_ratio + test_ratio < 1 + train_ratio = 1 - val_ratio - test_ratio + print("Using rest of the dataset except the test and val sets.") + else: + assert train_ratio + val_ratio + test_ratio <= 1 + # indices = list(range(total_size)) + if n_train is None: + n_train = int(train_ratio * total_size) + if n_test is None: + n_test = int(test_ratio * total_size) + if n_val is None: + n_val = int(val_ratio * total_size) + ids = list(np.arange(total_size)) + if not keep_data_order: + random.seed(split_seed) + random.shuffle(ids) + if n_train + n_val + n_test > total_size: + raise ValueError( + "Check total number of samples.", + n_train + n_val + n_test, + ">", + total_size, + ) + + id_train = ids[:n_train] + id_val = ids[-(n_val + n_test) : -n_test] # noqa:E203 + id_test = ids[-n_test:] + return id_train, id_val, id_test + + +def get_torch_dataset( + dataset=[], + id_tag="jid", + target="", + neighbor_strategy="", + atom_features="", + use_canonize="", + name="", + line_graph="", + cutoff=8.0, + max_neighbors=12, + classification=False, + output_dir=".", + tmp_name="dataset", +): + """Get Torch Dataset.""" + df = pd.DataFrame(dataset) + # print("df", df) + vals = df[target].values + if target == "shear modulus" or target == "bulk modulus": + val_list = [vals[i].item() for i in range(len(vals))] + vals = val_list + print("data range", np.max(vals), np.min(vals)) + print("data mean and std", np.mean(vals), np.std(vals)) + f = open(os.path.join(output_dir, tmp_name + "_data_range"), "w") + line = "Max=" + str(np.max(vals)) + "\n" + f.write(line) + line = "Min=" + str(np.min(vals)) + "\n" + f.write(line) + f.close() + + graphs = load_graphs( + df, + name=name, + neighbor_strategy=neighbor_strategy, + use_canonize=use_canonize, + cutoff=cutoff, + max_neighbors=max_neighbors, + ) + + data = StructureDataset( + df, + graphs, + target=target, + atom_features=atom_features, + line_graph=line_graph, + id_tag=id_tag, + classification=classification, + ) + return data + +def get_pyg_dataset( + dataset=[], + id_tag="jid", + target="", + neighbor_strategy="", + atom_features="", + use_canonize="", + name="", + line_graph="", + cutoff=8.0, + max_neighbors=12, + classification=False, + output_dir=".", + tmp_name="dataset", + use_lattice=False, + use_angle=False, + data_from='Jarvis', + use_save=False, + mean_train=None, + std_train=None, + now=False, # for test +): + """Get pyg Dataset.""" + df = pd.DataFrame(dataset) + # print("df", df) + # neighbor_strategy 
= "pairwise-k-nearest" + vals = df[target].values + if target == "shear modulus" or target == "bulk modulus": + val_list = [vals[i].item() for i in range(len(vals))] + vals = val_list + output_dir = "./saved_data/" + tmp_name + "test_graph_angle.pkl" # for fast test use + print("data range", np.max(vals), np.min(vals)) + print(output_dir) + print('graphs not saved') + graphs = load_pyg_graphs( + df, + name=name, + neighbor_strategy=neighbor_strategy, + use_canonize=use_canonize, + cutoff=cutoff, + max_neighbors=max_neighbors, + use_lattice=use_lattice, + use_angle=use_angle, + ) + if mean_train == None: + mean_train = np.mean(vals) + std_train = np.std(vals) + data = PygStructureDataset( + df, + graphs, + target=target, + atom_features=atom_features, + line_graph=line_graph, + id_tag=id_tag, + classification=classification, + neighbor_strategy=neighbor_strategy, + mean_train=mean_train, + std_train=std_train, + ) + else: + data = PygStructureDataset( + df, + graphs, + target=target, + atom_features=atom_features, + line_graph=line_graph, + id_tag=id_tag, + classification=classification, + neighbor_strategy=neighbor_strategy, + mean_train=mean_train, + std_train=std_train, + ) + return data, mean_train, std_train + + +def get_train_val_loaders( + dataset: str = "dft_3d", + dataset_array=[], + target: str = "formation_energy_peratom", + atom_features: str = "cgcnn", + neighbor_strategy: str = "k-nearest", + n_train=None, + n_val=None, + n_test=None, + train_ratio=None, + val_ratio=0.1, + test_ratio=0.1, + batch_size: int = 5, + standardize: bool = False, + line_graph: bool = True, + split_seed: int = 123, + workers: int = 0, + pin_memory: bool = True, + save_dataloader: bool = False, + filename: str = "sample", + id_tag: str = "jid", + use_canonize: bool = False, + cutoff: float = 8.0, + max_neighbors: int = 12, + classification_threshold: Optional[float] = None, + target_multiplication_factor: Optional[float] = None, + standard_scalar_and_pca=False, + keep_data_order=False, + output_features=1, + output_dir=None, + matrix_input=False, + pyg_input=False, + use_lattice=False, + use_angle=False, + use_save=True, + mp_id_list=None, + train_inputs=None, + train_outputs=None, + test_inputs=None, + test_outputs=None, +): + """Help function to set up JARVIS train and val dataloaders.""" + # data loading + mean_train=None + std_train=None + assert (matrix_input and pyg_input) == False + + # train_sample = filename + "_train.data" + # val_sample = filename + "_val.data" + # test_sample = filename + "_test.data" + if not os.path.exists(output_dir): + os.makedirs(output_dir) + # if ( + # os.path.exists(train_sample) + # and os.path.exists(val_sample) + # and os.path.exists(test_sample) + # and save_dataloader + # ): + # print("Loading from saved file...") + # print("Make sure all the DataLoader params are same.") + # print("This module is made for debugging only.") + # train_loader = torch.load(train_sample) + # val_loader = torch.load(val_sample) + # test_loader = torch.load(test_sample) + # if train_loader.pin_memory != pin_memory: + # train_loader.pin_memory = pin_memory + # if test_loader.pin_memory != pin_memory: + # test_loader.pin_memory = pin_memory + # if val_loader.pin_memory != pin_memory: + # val_loader.pin_memory = pin_memory + # if train_loader.num_workers != workers: + # train_loader.num_workers = workers + # if test_loader.num_workers != workers: + # test_loader.num_workers = workers + # if val_loader.num_workers != workers: + # val_loader.num_workers = workers + # print("train", 
len(train_loader.dataset)) + # print("val", len(val_loader.dataset)) + # print("test", len(test_loader.dataset)) + # return ( + # train_loader, + # val_loader, + # test_loader, + # train_loader.dataset.prepare_batch, + # ) + # else: + # if not dataset_array: + # d = jdata(dataset) + # else: + # d = dataset_array + # # for ii, i in enumerate(pc_y): + # # d[ii][target] = pc_y[ii].tolist() + + # dat = [] + # if classification_threshold is not None: + # print( + # "Using ", + # classification_threshold, + # " for classifying ", + # target, + # " data.", + # ) + # print("Converting target data into 1 and 0.") + # all_targets = [] + + # # TODO:make an all key in qm9_dgl + # if dataset == "qm9_dgl" and target == "all": + # print("Making all qm9_dgl") + # tmp = [] + # for ii in d: + # ii["all"] = [ + # ii["mu"], + # ii["alpha"], + # ii["homo"], + # ii["lumo"], + # ii["gap"], + # ii["r2"], + # ii["zpve"], + # ii["U0"], + # ii["U"], + # ii["H"], + # ii["G"], + # ii["Cv"], + # ] + # tmp.append(ii) + # print("Made all qm9_dgl") + # d = tmp + # for i in d: + # if isinstance(i[target], list): # multioutput target + # all_targets.append(torch.tensor(i[target])) + # dat.append(i) + + # elif ( + # i[target] is not None + # and i[target] != "na" + # and not math.isnan(i[target]) + # ): + # if target_multiplication_factor is not None: + # i[target] = i[target] * target_multiplication_factor + # if classification_threshold is not None: + # if i[target] <= classification_threshold: + # i[target] = 0 + # elif i[target] > classification_threshold: + # i[target] = 1 + # else: + # raise ValueError( + # "Check classification data type.", + # i[target], + # type(i[target]), + # ) + # dat.append(i) + # all_targets.append(i[target]) + + + # if mp_id_list is not None: + # if mp_id_list == 'bulk': + # print('using mp bulk dataset') + # with open('/data/keqiangyan/bulk_shear/bulk_megnet_train.pkl', 'rb') as f: + # dataset_train = pk.load(f) + # with open('/data/keqiangyan/bulk_shear/bulk_megnet_val.pkl', 'rb') as f: + # dataset_val = pk.load(f) + # with open('/data/keqiangyan/bulk_shear/bulk_megnet_test.pkl', 'rb') as f: + # dataset_test = pk.load(f) + + # if mp_id_list == 'shear': + # print('using mp shear dataset') + # with open('/data/keqiangyan/bulk_shear/shear_megnet_train.pkl', 'rb') as f: + # dataset_train = pk.load(f) + # with open('/data/keqiangyan/bulk_shear/shear_megnet_val.pkl', 'rb') as f: + # dataset_val = pk.load(f) + # with open('/data/keqiangyan/bulk_shear/shear_megnet_test.pkl', 'rb') as f: + # dataset_test = pk.load(f) + + # else: + # id_train, id_val, id_test = get_id_train_val_test( + # total_size=len(dat), + # split_seed=split_seed, + # train_ratio=train_ratio, + # val_ratio=val_ratio, + # test_ratio=test_ratio, + # n_train=n_train, + # n_test=n_test, + # n_val=n_val, + # keep_data_order=keep_data_order, + # ) + # ids_train_val_test = {} + # ids_train_val_test["id_train"] = [dat[i][id_tag] for i in id_train] + # ids_train_val_test["id_val"] = [dat[i][id_tag] for i in id_val] + # ids_train_val_test["id_test"] = [dat[i][id_tag] for i in id_test] + # dumpjson( + # data=ids_train_val_test, + # filename=os.path.join(output_dir, "ids_train_val_test.json"), + # ) + # dataset_train = [dat[x] for x in id_train] + # dataset_val = [dat[x] for x in id_val] + # dataset_test = [dat[x] for x in id_test] + + dataset_train = [] + dataset_val = [] + dataset_test = [] + for i in range(len(train_inputs)): + dataset_train.append({"atoms":train_inputs[i], "label": train_outputs[i]}) + + for i in range(len(test_inputs)): 
+ dataset_val.append({"atoms":test_inputs[i], "label": test_outputs[i]}) + dataset_test.append({"atoms":test_inputs[i], "label": test_outputs[i]}) + + print("Number of train data: ", len(dataset_train)) + print("Number of test data: ", len(dataset_test)) + + # import pdb; pdb.set_trace() + + # if standard_scalar_and_pca: + # y_data = [i[target] for i in dataset_train] + # # pipe = Pipeline([('scale', StandardScaler())]) + # if not isinstance(y_data[0], list): + # print("Running StandardScalar") + # y_data = np.array(y_data).reshape(-1, 1) + # sc = StandardScaler() + + # sc.fit(y_data) + # print("Mean", sc.mean_) + # print("Variance", sc.var_) + # try: + # print("New max", max(y_data)) + # print("New min", min(y_data)) + # except Exception as exp: + # print(exp) + # pass + + # pk.dump(sc, open(os.path.join(output_dir, "sc.pkl"), "wb")) + + # if classification_threshold is None: + # try: + # from sklearn.metrics import mean_absolute_error + + # print("MAX val:", max(all_targets)) + # print("MIN val:", min(all_targets)) + # print("MAD:", mean_absolute_deviation(all_targets)) + # try: + # f = open(os.path.join(output_dir, "mad"), "w") + # line = "MAX val:" + str(max(all_targets)) + "\n" + # line += "MIN val:" + str(min(all_targets)) + "\n" + # line += ( + # "MAD val:" + # + str(mean_absolute_deviation(all_targets)) + # + "\n" + # ) + # f.write(line) + # f.close() + # except Exception as exp: + # print("Cannot write mad", exp) + # pass + # # Random model precited value + # x_bar = np.mean(np.array([i[target] for i in dataset_train])) + # baseline_mae = mean_absolute_error( + # np.array([i[target] for i in dataset_test]), + # np.array([x_bar for i in dataset_test]), + # ) + # print("Baseline MAE:", baseline_mae) + # except Exception as exp: + # print("Data error", exp) + # pass + + train_data, mean_train, std_train = get_pyg_dataset( + dataset=dataset_train, + id_tag=id_tag, + atom_features=atom_features, + target=target, + neighbor_strategy=neighbor_strategy, + use_canonize=use_canonize, + name=dataset, + line_graph=line_graph, + cutoff=cutoff, + max_neighbors=max_neighbors, + classification=classification_threshold is not None, + output_dir=output_dir, + tmp_name="train_data", + use_lattice=use_lattice, + use_angle=use_angle, + use_save=False, + ) + # import pdb; pdb.set_trace() + val_data,_,_ = get_pyg_dataset( + dataset=dataset_val, + id_tag=id_tag, + atom_features=atom_features, + target=target, + neighbor_strategy=neighbor_strategy, + use_canonize=use_canonize, + name=dataset, + line_graph=line_graph, + cutoff=cutoff, + max_neighbors=max_neighbors, + classification=classification_threshold is not None, + output_dir=output_dir, + tmp_name="val_data", + use_lattice=use_lattice, + use_angle=use_angle, + use_save=False, + mean_train=mean_train, + std_train=std_train, + ) + test_data,_,_ = get_pyg_dataset( + dataset=dataset_test, + id_tag=id_tag, + atom_features=atom_features, + target=target, + neighbor_strategy=neighbor_strategy, + use_canonize=use_canonize, + name=dataset, + line_graph=line_graph, + cutoff=cutoff, + max_neighbors=max_neighbors, + classification=classification_threshold is not None, + output_dir=output_dir, + tmp_name="test_data", + use_lattice=use_lattice, + use_angle=use_angle, + use_save=False, + mean_train=mean_train, + std_train=std_train, + ) + + collate_fn = train_data.collate + if line_graph: + collate_fn = train_data.collate_line_graph + + # use a regular pytorch dataloader + train_loader = DataLoader( + train_data, + batch_size=batch_size, + shuffle=True, + 
collate_fn=collate_fn, + drop_last=True, + num_workers=workers, + pin_memory=pin_memory, + ) + + val_loader = DataLoader( + val_data, + batch_size=batch_size, + shuffle=False, + collate_fn=collate_fn, + drop_last=True, + num_workers=workers, + pin_memory=pin_memory, + ) + + test_loader = DataLoader( + test_data, + batch_size=1, + shuffle=False, + collate_fn=collate_fn, + drop_last=False, + num_workers=workers, + pin_memory=pin_memory, + ) + if save_dataloader: + torch.save(train_loader, train_sample) + torch.save(val_loader, val_sample) + torch.save(test_loader, test_sample) + + print("n_train:", len(train_loader.dataset)) + print("n_val:", len(val_loader.dataset)) + print("n_test:", len(test_loader.dataset)) + return ( + train_loader, + val_loader, + test_loader, + train_loader.dataset.prepare_batch, + mean_train, + std_train, + ) + diff --git a/benchmarks/matbench_v0.1_eComFormer/features.py b/benchmarks/matbench_v0.1_eComFormer/features.py new file mode 100644 index 00000000..8e4edcb7 --- /dev/null +++ b/benchmarks/matbench_v0.1_eComFormer/features.py @@ -0,0 +1,265 @@ +# Based on the code from: https://github.com/klicperajo/dimenet, +# https://github.com/rusty1s/pytorch_geometric/blob/master/torch_geometric/nn/models/dimenet_utils.py + + +import numpy as np +from scipy.optimize import brentq +from scipy import special as sp +import torch +from math import pi as PI + +import sympy as sym + +device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') + +def Jn(r, n): + return np.sqrt(np.pi / (2 * r)) * sp.jv(n + 0.5, r) + + +def Jn_zeros(n, k): + zerosj = np.zeros((n, k), dtype='float32') + zerosj[0] = np.arange(1, k + 1) * np.pi + points = np.arange(1, k + n) * np.pi + racines = np.zeros(k + n - 1, dtype='float32') + for i in range(1, n): + for j in range(k + n - 1 - i): + foo = brentq(Jn, points[j], points[j + 1], (i, )) + racines[j] = foo + points = racines + zerosj[i][:k] = racines[:k] + + return zerosj + + +def spherical_bessel_formulas(n): + x = sym.symbols('x') + + f = [sym.sin(x) / x] + a = sym.sin(x) / x + for i in range(1, n): + b = sym.diff(a, x) / x + f += [sym.simplify(b * (-x)**i)] + a = sym.simplify(b) + return f + + +def bessel_basis(n, k): + zeros = Jn_zeros(n, k) + normalizer = [] + for order in range(n): + normalizer_tmp = [] + for i in range(k): + normalizer_tmp += [0.5 * Jn(zeros[order, i], order + 1)**2] + normalizer_tmp = 1 / np.array(normalizer_tmp)**0.5 + normalizer += [normalizer_tmp] + + f = spherical_bessel_formulas(n) + x = sym.symbols('x') + bess_basis = [] + for order in range(n): + bess_basis_tmp = [] + for i in range(k): + bess_basis_tmp += [ + sym.simplify(normalizer[order][i] * + f[order].subs(x, zeros[order, i] * x)) + ] + bess_basis += [bess_basis_tmp] + return bess_basis + + +def sph_harm_prefactor(k, m): + return ((2 * k + 1) * np.math.factorial(k - abs(m)) / + (4 * np.pi * np.math.factorial(k + abs(m))))**0.5 + + +def associated_legendre_polynomials(k, zero_m_only=True): + z = sym.symbols('z') + P_l_m = [[0] * (j + 1) for j in range(k)] + + P_l_m[0][0] = 1 + if k > 0: + P_l_m[1][0] = z + + for j in range(2, k): + P_l_m[j][0] = sym.simplify(((2 * j - 1) * z * P_l_m[j - 1][0] - + (j - 1) * P_l_m[j - 2][0]) / j) + if not zero_m_only: + for i in range(1, k): + P_l_m[i][i] = sym.simplify((1 - 2 * i) * P_l_m[i - 1][i - 1]) + if i + 1 < k: + P_l_m[i + 1][i] = sym.simplify( + (2 * i + 1) * z * P_l_m[i][i]) + for j in range(i + 2, k): + P_l_m[j][i] = sym.simplify( + ((2 * j - 1) * z * P_l_m[j - 1][i] - + (i + j - 1) * P_l_m[j - 2][i]) / (j - i)) + 
+ return P_l_m + + +def real_sph_harm(l, zero_m_only=False, spherical_coordinates=True): + """ + Computes formula strings of the the real part of the spherical harmonics up to order l (excluded). + Variables are either cartesian coordinates x,y,z on the unit sphere or spherical coordinates phi and theta. + """ + if not zero_m_only: + x = sym.symbols('x') + y = sym.symbols('y') + S_m = [x*0] + C_m = [1+0*x] + # S_m = [0] + # C_m = [1] + for i in range(1, l): + x = sym.symbols('x') + y = sym.symbols('y') + S_m += [x*S_m[i-1] + y*C_m[i-1]] + C_m += [x*C_m[i-1] - y*S_m[i-1]] + + P_l_m = associated_legendre_polynomials(l, zero_m_only) + if spherical_coordinates: + theta = sym.symbols('theta') + z = sym.symbols('z') + for i in range(len(P_l_m)): + for j in range(len(P_l_m[i])): + if type(P_l_m[i][j]) != int: + P_l_m[i][j] = P_l_m[i][j].subs(z, sym.cos(theta)) + if not zero_m_only: + phi = sym.symbols('phi') + for i in range(len(S_m)): + S_m[i] = S_m[i].subs(x, sym.sin( + theta)*sym.cos(phi)).subs(y, sym.sin(theta)*sym.sin(phi)) + for i in range(len(C_m)): + C_m[i] = C_m[i].subs(x, sym.sin( + theta)*sym.cos(phi)).subs(y, sym.sin(theta)*sym.sin(phi)) + + Y_func_l_m = [['0']*(2*j + 1) for j in range(l)] + for i in range(l): + Y_func_l_m[i][0] = sym.simplify(sph_harm_prefactor(i, 0) * P_l_m[i][0]) + + if not zero_m_only: + for i in range(1, l): + for j in range(1, i + 1): + Y_func_l_m[i][j] = sym.simplify( + 2**0.5 * sph_harm_prefactor(i, j) * C_m[j] * P_l_m[i][j]) + for i in range(1, l): + for j in range(1, i + 1): + Y_func_l_m[i][-j] = sym.simplify( + 2**0.5 * sph_harm_prefactor(i, -j) * S_m[j] * P_l_m[i][j]) + + return Y_func_l_m + + +class Envelope(torch.nn.Module): + def __init__(self, exponent): + super(Envelope, self).__init__() + self.p = exponent + 1 + self.a = -(self.p + 1) * (self.p + 2) / 2 + self.b = self.p * (self.p + 2) + self.c = -self.p * (self.p + 1) / 2 + + def forward(self, x): + p, a, b, c = self.p, self.a, self.b, self.c + x_pow_p0 = x.pow(p - 1) + x_pow_p1 = x_pow_p0 * x + x_pow_p2 = x_pow_p1 * x + return 1. 
/ x + a * x_pow_p0 + b * x_pow_p1 + c * x_pow_p2 + + +class dist_emb(torch.nn.Module): + def __init__(self, num_radial, cutoff=5.0, envelope_exponent=5): + super(dist_emb, self).__init__() + self.cutoff = cutoff + self.envelope = Envelope(envelope_exponent) + + self.freq = torch.nn.Parameter(torch.Tensor(num_radial)) + + self.reset_parameters() + + def reset_parameters(self): + torch.arange(1, self.freq.numel() + 1, out=self.freq).mul_(PI) + + def forward(self, dist): + dist = dist.unsqueeze(-1) / self.cutoff + return self.envelope(dist) * (self.freq * dist).sin() + + +class angle_emb_mp(torch.nn.Module): + def __init__(self, num_spherical=3, num_radial=30, cutoff=8.0, + envelope_exponent=5): + super(angle_emb_mp, self).__init__() + assert num_radial <= 64 + self.num_spherical = num_spherical + self.num_radial = num_radial + self.cutoff = cutoff + # self.envelope = Envelope(envelope_exponent) + + bessel_forms = bessel_basis(num_spherical, num_radial) + sph_harm_forms = real_sph_harm(num_spherical) + self.sph_funcs = [] + self.bessel_funcs = [] + + x, theta = sym.symbols('x theta') + modules = {'sin': torch.sin, 'cos': torch.cos} + for i in range(num_spherical): + if i == 0: + sph1 = sym.lambdify([theta], sph_harm_forms[i][0], modules)(0) + self.sph_funcs.append(lambda x: torch.zeros_like(x) + sph1) + else: + sph = sym.lambdify([theta], sph_harm_forms[i][0], modules) + self.sph_funcs.append(sph) + for j in range(num_radial): + bessel = sym.lambdify([x], bessel_forms[i][j], modules) + self.bessel_funcs.append(bessel) + + def forward(self, dist, angle, idx_kj): + dist = dist / self.cutoff + rbf = torch.stack([f(dist) for f in self.bessel_funcs], dim=1) + # rbf = self.envelope(dist).unsqueeze(-1) * rbf + + cbf = torch.stack([f(angle) for f in self.sph_funcs], dim=1) + + n, k = self.num_spherical, self.num_radial + out = (rbf[idx_kj].view(-1, n, k) * cbf.view(-1, n, 1)).view(-1, n * k) + return out + + +class torsion_emb(torch.nn.Module): + def __init__(self, num_spherical, num_radial, cutoff=5.0, + envelope_exponent=5): + super(torsion_emb, self).__init__() + assert num_radial <= 64 + self.num_spherical = num_spherical # + self.num_radial = num_radial + self.cutoff = cutoff + # self.envelope = Envelope(envelope_exponent) + + bessel_forms = bessel_basis(num_spherical, num_radial) + sph_harm_forms = real_sph_harm(num_spherical, zero_m_only=False) + self.sph_funcs = [] + self.bessel_funcs = [] + + x = sym.symbols('x') + theta = sym.symbols('theta') + phi = sym.symbols('phi') + modules = {'sin': torch.sin, 'cos': torch.cos} + for i in range(self.num_spherical): + if i == 0: + sph1 = sym.lambdify([theta, phi], sph_harm_forms[i][0], modules) + self.sph_funcs.append(lambda x, y: torch.zeros_like(x) + torch.zeros_like(y) + sph1(0,0)) #torch.zeros_like(x) + torch.zeros_like(y) + else: + for k in range(-i, i + 1): + sph = sym.lambdify([theta, phi], sph_harm_forms[i][k+i], modules) + self.sph_funcs.append(sph) + for j in range(self.num_radial): + bessel = sym.lambdify([x], bessel_forms[i][j], modules) + self.bessel_funcs.append(bessel) + + def forward(self, dist, angle, phi, idx_kj): + dist = dist / self.cutoff + rbf = torch.stack([f(dist) for f in self.bessel_funcs], dim=1) + cbf = torch.stack([f(angle, phi) for f in self.sph_funcs], dim=1) + + n, k = self.num_spherical, self.num_radial + out = (rbf[idx_kj].view(-1, 1, n, k) * cbf.view(-1, n, n, 1)).view(-1, n * n * k) + return out + diff --git a/benchmarks/matbench_v0.1_eComFormer/graphs.py b/benchmarks/matbench_v0.1_eComFormer/graphs.py new file mode 
100644
index 00000000..2032befd
--- /dev/null
+++ b/benchmarks/matbench_v0.1_eComFormer/graphs.py
@@ -0,0 +1,573 @@
+"""Module to generate PyTorch Geometric crystal graphs."""
+"""Implementation based on the template of ALIGNN."""
+import numpy as np
+import pandas as pd
+from jarvis.core.specie import chem_data, get_node_attributes
+
+# from jarvis.core.atoms import Atoms
+from collections import defaultdict
+from typing import List, Tuple, Sequence, Optional
+import torch
+from torch_geometric.data import Data
+from torch_geometric.transforms import LineGraph
+from torch_geometric.data.batch import Batch
+import itertools
+
+try:
+    import torch
+    from tqdm import tqdm
+except Exception as exp:
+    print("torch/tqdm is not installed.", exp)
+    pass
+
+
+def angle_from_array(a, b, lattice):
+    a_new = np.dot(a, lattice)
+    b_new = np.dot(b, lattice)
+    assert a_new.shape == a.shape
+    value = sum(a_new * b_new)
+    length = (sum(a_new ** 2) ** 0.5) * (sum(b_new ** 2) ** 0.5)
+    cos = value / length
+    angle = np.arccos(cos)
+    return angle / np.pi * 180.0
+
+def correct_coord_sys(a, b, c, lattice):
+    a_new = np.dot(a, lattice)
+    b_new = np.dot(b, lattice)
+    c_new = np.dot(c, lattice)
+    assert a_new.shape == a.shape
+    plane_vec = np.cross(a_new, b_new)
+    value = sum(plane_vec * c_new)
+    length = (sum(plane_vec ** 2) ** 0.5) * (sum(c_new ** 2) ** 0.5)
+    cos = value / length
+    angle = np.arccos(cos)
+    return (angle / np.pi * 180.0 <= 90.0)
+
+def same_line(a, b):
+    a_new = a / (sum(a ** 2) ** 0.5)
+    b_new = b / (sum(b ** 2) ** 0.5)
+    flag = False
+    if abs(sum(a_new * b_new) - 1.0) < 1e-5:
+        flag = True
+    elif abs(sum(a_new * b_new) + 1.0) < 1e-5:
+        flag = True
+    else:
+        flag = False
+    return flag
+
+def same_plane(a, b, c):
+    flag = False
+    if abs(np.dot(np.cross(a, b), c)) < 1e-5:
+        flag = True
+    return flag
+
+# pyg dataset
+class PygStructureDataset(torch.utils.data.Dataset):
+    """Dataset of crystal PyTorch Geometric graphs."""
+
+    def __init__(
+        self,
+        df: pd.DataFrame,
+        graphs: Sequence[Data],
+        target: str,
+        atom_features="atomic_number",
+        transform=None,
+        line_graph=False,
+        classification=False,
+        id_tag="jid",
+        neighbor_strategy="",
+        nolinegraph=False,
+        mean_train=None,
+        std_train=None,
+    ):
+        """PyTorch Dataset for atomistic graphs.
+
+        `df`: pandas dataframe from e.g.
jarvis.db.figshare.data + `graphs`: DGLGraph representations corresponding to rows in `df` + `target`: key for label column in `df` + """ + self.df = df + self.graphs = graphs + self.target = target + self.line_graph = line_graph + + # self.ids = self.df[id_tag] + self.atoms = self.df['atoms'] + self.labels = torch.tensor(self.df[target]).type( + torch.get_default_dtype() + ) + print("mean %f std %f"%(self.labels.mean(), self.labels.std())) + if mean_train == None: + mean = self.labels.mean() + std = self.labels.std() + self.labels = (self.labels - mean) / std + print("normalize using training mean but shall not be used here %f and std %f" % (mean, std)) + else: + self.labels = (self.labels - mean_train) / std_train + print("normalize using training mean %f and std %f" % (mean_train, std_train)) + + self.transform = transform + + features = self._get_attribute_lookup(atom_features) + + # load selected node representation + # assume graphs contain atomic number in g.ndata["atom_features"] + for g in graphs: + z = g.x + g.atomic_number = z + z = z.type(torch.IntTensor).squeeze() + f = torch.tensor(features[z]).type(torch.FloatTensor) + if g.x.size(0) == 1: + f = f.unsqueeze(0) + g.x = f + + self.prepare_batch = prepare_pyg_batch + if line_graph: + self.prepare_batch = prepare_pyg_line_graph_batch + print("building line graphs") + # if not nolinegraph: + # self.line_graphs = [] + # self.graphs = [] + # for g in tqdm(graphs): + # linegraph_trans = LineGraph(force_directed=True) + # g_new = Data() + # g_new.x, g_new.edge_index, g_new.edge_attr, g_new.edge_type = g.x, g.edge_index, g.edge_attr, g.edge_type + # try: + # lg = linegraph_trans(g) + # except Exception as exp: + # print(g.x, g.edge_attr, exp) + # pass + # lg.edge_attr = pyg_compute_bond_cosines(lg) # old cosine emb + # # lg.edge_attr = pyg_compute_bond_angle(lg) + # self.graphs.append(g_new) + # self.line_graphs.append(lg) + # else: + # + self.graphs = [] + for g in tqdm(graphs): + g.edge_attr = g.edge_attr.float() + self.graphs.append(g) + self.line_graphs = self.graphs + + + if classification: + self.labels = self.labels.view(-1).long() + print("Classification dataset.", self.labels) + + @staticmethod + def _get_attribute_lookup(atom_features: str = "cgcnn"): + """Build a lookup array indexed by atomic number.""" + max_z = max(v["Z"] for v in chem_data.values()) + + # get feature shape (referencing Carbon) + template = get_node_attributes("C", atom_features) + + features = np.zeros((1 + max_z, len(template))) + + for element, v in chem_data.items(): + z = v["Z"] + x = get_node_attributes(element, atom_features) + + if x is not None: + features[z, :] = x + + return features + + def __len__(self): + """Get length.""" + return self.labels.shape[0] + + def __getitem__(self, idx): + """Get StructureDataset sample.""" + g = self.graphs[idx] + label = self.labels[idx] + + if self.line_graph: + return g, self.line_graphs[idx], label, label + + return g, label + + def setup_standardizer(self, ids): + """Atom-wise feature standardization transform.""" + x = torch.cat( + [ + g.x + for idx, g in enumerate(self.graphs) + if idx in ids + ] + ) + self.atom_feature_mean = x.mean(0) + self.atom_feature_std = x.std(0) + + self.transform = PygStandardize( + self.atom_feature_mean, self.atom_feature_std + ) + + @staticmethod + def collate(samples: List[Tuple[Data, torch.Tensor]]): + """Dataloader helper to batch graphs cross `samples`.""" + graphs, labels = map(list, zip(*samples)) + batched_graph = Batch.from_data_list(graphs) + return batched_graph, 
torch.tensor(labels) + + @staticmethod + def collate_line_graph( + samples: List[Tuple[Data, Data, torch.Tensor, torch.Tensor]] + ): + """Dataloader helper to batch graphs cross `samples`.""" + graphs, line_graphs, lattice, labels = map(list, zip(*samples)) + batched_graph = Batch.from_data_list(graphs) + batched_line_graph = Batch.from_data_list(line_graphs) + if len(labels[0].size()) > 0: + return batched_graph, batched_line_graph, torch.cat([i.unsqueeze(0) for i in lattice]), torch.stack(labels) + else: + return batched_graph, batched_line_graph, torch.cat([i.unsqueeze(0) for i in lattice]), torch.tensor(labels) + +def canonize_edge( + src_id, + dst_id, + src_image, + dst_image, +): + """Compute canonical edge representation. + + Sort vertex ids + shift periodic images so the first vertex is in (0,0,0) image + """ + # store directed edges src_id <= dst_id + if dst_id < src_id: + src_id, dst_id = dst_id, src_id + src_image, dst_image = dst_image, src_image + + # shift periodic images so that src is in (0,0,0) image + if not np.array_equal(src_image, (0, 0, 0)): + shift = src_image + src_image = tuple(np.subtract(src_image, shift)) + dst_image = tuple(np.subtract(dst_image, shift)) + + assert src_image == (0, 0, 0) + + return src_id, dst_id, src_image, dst_image + + +def nearest_neighbor_edges_submit( + atoms=None, + cutoff=8, + max_neighbors=12, + id=None, + use_canonize=False, + use_lattice=False, + use_angle=False, +): + """Construct k-NN edge list.""" + # returns List[List[Tuple[site, distance, index, image]]] + lat = atoms.lattice + all_neighbors_now = atoms.get_all_neighbors(r=cutoff) + min_nbrs = min(len(neighborlist) for neighborlist in all_neighbors_now) + + attempt = 0 + if min_nbrs < max_neighbors: + lat = atoms.lattice + if cutoff < max(lat.a, lat.b, lat.c): + r_cut = max(lat.a, lat.b, lat.c) + else: + r_cut = 2 * cutoff + attempt += 1 + return nearest_neighbor_edges_submit( + atoms=atoms, + use_canonize=use_canonize, + cutoff=r_cut, + max_neighbors=max_neighbors, + id=id, + use_lattice=use_lattice, + ) + + edges = defaultdict(set) + # lattice correction process + r_cut = max(lat.a, lat.b, lat.c) + 1e-2 + all_neighbors = atoms.get_all_neighbors(r=r_cut) + neighborlist = all_neighbors[0] + neighborlist = sorted(neighborlist, key=lambda x: x[2]) + ids = np.array([nbr[1] for nbr in neighborlist]) + images = np.array([nbr[3] for nbr in neighborlist]) + images = images[ids == 0] + lat1 = images[0] + # finding lat2 + start = 1 + for i in range(start, len(images)): + lat2 = images[i] + if not same_line(lat1, lat2): + start = i + break + # finding lat3 + for i in range(start, len(images)): + lat3 = images[i] + if not same_plane(lat1, lat2, lat3): + break + # find the invariant corner + if angle_from_array(lat1,lat2,lat.matrix) > 90.0: + lat2 = - lat2 + if angle_from_array(lat1,lat3,lat.matrix) > 90.0: + lat3 = - lat3 + # find the invariant coord system + if not correct_coord_sys(lat1, lat2, lat3, lat.matrix): + lat1 = - lat1 + lat2 = - lat2 + lat3 = - lat3 + + # if not correct_coord_sys(lat1, lat2, lat3, lat.matrix): + # print(lat1, lat2, lat3) + # lattice correction end + for site_idx, neighborlist in enumerate(all_neighbors_now): + + # sort on distance + neighborlist = sorted(neighborlist, key=lambda x: x[2]) + distances = np.array([nbr[2] for nbr in neighborlist]) + ids = np.array([nbr[1] for nbr in neighborlist]) + images = np.array([nbr[3] for nbr in neighborlist]) + + # find the distance to the k-th nearest neighbor + max_dist = distances[max_neighbors - 1] + ids = ids[distances 
<= max_dist] + images = images[distances <= max_dist] + distances = distances[distances <= max_dist] + for dst, image in zip(ids, images): + src_id, dst_id, src_image, dst_image = canonize_edge( + site_idx, dst, (0, 0, 0), tuple(image) + ) + if use_canonize: + edges[(src_id, dst_id)].add(dst_image) + else: + edges[(site_idx, dst)].add(tuple(image)) + + if use_lattice: + edges[(site_idx, site_idx)].add(tuple(lat1)) + edges[(site_idx, site_idx)].add(tuple(lat2)) + edges[(site_idx, site_idx)].add(tuple(lat3)) + + return edges, lat1, lat2, lat3 + + +def compute_bond_cosine(v1, v2): + """Compute bond angle cosines from bond displacement vectors.""" + v1 = torch.tensor(v1).type(torch.get_default_dtype()) + v2 = torch.tensor(v2).type(torch.get_default_dtype()) + bond_cosine = torch.sum(v1 * v2) / ( + torch.norm(v1) * torch.norm(v2) + ) + bond_cosine = torch.clamp(bond_cosine, -1, 1) + return bond_cosine + + +def build_undirected_edgedata( + atoms=None, + edges={}, + a=None, + b=None, + c=None, +): + """Build undirected graph data from edge set. + + edges: dictionary mapping (src_id, dst_id) to set of dst_image + r: cartesian displacement vector from src -> dst + """ + # second pass: construct *undirected* graph + # import pprint + u, v, r, l, nei, angle, atom_lat = [], [], [], [], [], [], [] + v1, v2, v3 = atoms.lattice.cart_coords(a), atoms.lattice.cart_coords(b), atoms.lattice.cart_coords(c) + atom_lat.append([v1, v2, v3]) + for (src_id, dst_id), images in edges.items(): + + for dst_image in images: + # fractional coordinate for periodic image of dst + dst_coord = atoms.frac_coords[dst_id] + dst_image + # cartesian displacement vector pointing from src -> dst + d = atoms.lattice.cart_coords( + dst_coord - atoms.frac_coords[src_id] + ) + for uu, vv, dd in [(src_id, dst_id, d), (dst_id, src_id, -d)]: + u.append(uu) + v.append(vv) + r.append(dd) + nei.append([v1, v2, v3]) + # angle.append([compute_bond_cosine(dd, v1), compute_bond_cosine(dd, v2), compute_bond_cosine(dd, v3)]) + # if np.linalg.norm(d)!=0: + # print ('jv',dst_image,d) + + u = torch.tensor(u) + v = torch.tensor(v) + r = torch.tensor(r).type(torch.get_default_dtype()) + l = torch.tensor(l).type(torch.int) + nei = torch.tensor(nei).type(torch.get_default_dtype()) + atom_lat = torch.tensor(atom_lat).type(torch.get_default_dtype()) + # nei_angles = torch.tensor(angle).type(torch.get_default_dtype()) + return u, v, r, l, nei, atom_lat + + +class PygGraph(object): + """Generate a graph object.""" + + def __init__( + self, + nodes=[], + node_attributes=[], + edges=[], + edge_attributes=[], + color_map=None, + labels=None, + ): + """ + Initialize the graph object. + + Args: + nodes: IDs of the graph nodes as integer array. + + node_attributes: node features as multi-dimensional array. + + edges: connectivity as a (u,v) pair where u is + the source index and v the destination ID. + + edge_attributes: attributes for each connectivity. + as simple as euclidean distances. 
+ """ + self.nodes = nodes + self.node_attributes = node_attributes + self.edges = edges + self.edge_attributes = edge_attributes + self.color_map = color_map + self.labels = labels + + @staticmethod + def atom_dgl_multigraph( + atoms=None, + neighbor_strategy="k-nearest", + cutoff=4.0, + max_neighbors=12, + atom_features="cgcnn", + max_attempts=3, + id: Optional[str] = None, + compute_line_graph: bool = True, + use_canonize: bool = False, + use_lattice: bool = False, + use_angle: bool = False, + ): + if neighbor_strategy == "k-nearest": + edges, a, b, c = nearest_neighbor_edges_submit( + atoms=atoms, + cutoff=cutoff, + max_neighbors=max_neighbors, + id=id, + use_canonize=use_canonize, + use_lattice=use_lattice, + use_angle=use_angle, + ) + u, v, r, l, nei, atom_lat = build_undirected_edgedata(atoms, edges, a, b, c) + elif neighbor_strategy == "pairwise-k-nearest": + u, v, r = pair_nearest_neighbor_edges( + atoms=atoms, + pair_wise_distances=2, + use_lattice=use_lattice, + use_angle=use_angle, + ) + else: + raise ValueError("Not implemented yet", neighbor_strategy) + + + # build up atom attribute tensor + sps_features = [] + for ii, s in enumerate(atoms.elements): + feat = list(get_node_attributes(s, atom_features=atom_features)) + sps_features.append(feat) + sps_features = np.array(sps_features) + node_features = torch.tensor(sps_features).type( + torch.get_default_dtype() + ) + atom_lat = atom_lat.repeat(node_features.shape[0],1,1) + edge_index = torch.cat((u.unsqueeze(0), v.unsqueeze(0)), dim=0).long() + g = Data(x=node_features, edge_index=edge_index, edge_attr=r, edge_type=l, edge_nei=nei, atom_lat=atom_lat) + + return g + +def pyg_compute_bond_cosines(lg): + """Compute bond angle cosines from bond displacement vectors.""" + # line graph edge: (a, b), (b, c) + # `a -> b -> c` + # use law of cosines to compute angles cosines + # negate src bond so displacements are like `a <- b -> c` + # cos(theta) = ba \dot bc / (||ba|| ||bc||) + src, dst = lg.edge_index + x = lg.x + r1 = -x[src] + r2 = x[dst] + bond_cosine = torch.sum(r1 * r2, dim=1) / ( + torch.norm(r1, dim=1) * torch.norm(r2, dim=1) + ) + bond_cosine = torch.clamp(bond_cosine, -1, 1) + return bond_cosine + +def pyg_compute_bond_angle(lg): + """Compute bond angle from bond displacement vectors.""" + # line graph edge: (a, b), (b, c) + # `a -> b -> c` + src, dst = lg.edge_index + x = lg.x + r1 = -x[src] + r2 = x[dst] + a = (r1 * r2).sum(dim=-1) # cos_angle * |pos_ji| * |pos_jk| + b = torch.cross(r1, r2).norm(dim=-1) # sin_angle * |pos_ji| * |pos_jk| + angle = torch.atan2(b, a) + return angle + + + +class PygStandardize(torch.nn.Module): + """Standardize atom_features: subtract mean and divide by std.""" + + def __init__(self, mean: torch.Tensor, std: torch.Tensor): + """Register featurewise mean and standard deviation.""" + super().__init__() + self.mean = mean + self.std = std + + def forward(self, g: Data): + """Apply standardization to atom_features.""" + h = g.x + g.x = (h - self.mean) / self.std + return g + + + +def prepare_pyg_batch( + batch: Tuple[Data, torch.Tensor], device=None, non_blocking=False +): + """Send batched dgl crystal graph to device.""" + g, t = batch + batch = ( + g.to(device), + t.to(device, non_blocking=non_blocking), + ) + + return batch + + +def prepare_pyg_line_graph_batch( + batch: Tuple[Tuple[Data, Data, torch.Tensor], torch.Tensor], + device=None, + non_blocking=False, +): + """Send line graph batch to device. 
+
+    Note: the batch is a nested tuple, with the graph and line graph together
+    """
+    g, lg, lattice, t = batch
+    batch = (
+        (
+            g.to(device),
+            lg.to(device),
+            lattice.to(device, non_blocking=non_blocking),
+        ),
+        t.to(device, non_blocking=non_blocking),
+    )
+
+    return batch
+
diff --git a/benchmarks/matbench_v0.1_eComFormer/info.json b/benchmarks/matbench_v0.1_eComFormer/info.json
new file mode 100644
index 00000000..c21ad35d
--- /dev/null
+++ b/benchmarks/matbench_v0.1_eComFormer/info.json
@@ -0,0 +1,15 @@
+{
+    "authors": "Keqiang Yan, Cong Fu, Xiaofeng Qian, Xiaoning Qian, Shuiwang Ji",
+    "algorithm": "eComFormer",
+    "algorithm_long": "Complete and efficient graph transformer for materials property prediction",
+    "bibtex_refs": "@inproceedings{ \n yan2024complete, \n title={Complete and Efficient Graph Transformers for Crystal Material Property Prediction},\n author={Keqiang Yan and Cong Fu and Xiaofeng Qian and Xiaoning Qian and Shuiwang Ji},\n booktitle={The Twelfth International Conference on Learning Representations},\n year={2024},\n url={https://openreview.net/forum?id=BnQY9XiRAS}\n}",
+    "notes": "This is the equivariant version of ComFormer",
+    "requirements": {
+        "python": [
+            "pytorch==1.13.1",
+            "torch_geometric==2.3.0",
+            "matbench==0.1.0",
+            "pymatgen==2023.3.23"
+        ]
+    }
+}
\ No newline at end of file
diff --git a/benchmarks/matbench_v0.1_eComFormer/models/__init__.py b/benchmarks/matbench_v0.1_eComFormer/models/__init__.py
new file mode 100644
index 00000000..ccc4f536
--- /dev/null
+++ b/benchmarks/matbench_v0.1_eComFormer/models/__init__.py
@@ -0,0 +1 @@
+"""Graph neural network implementations."""
diff --git a/benchmarks/matbench_v0.1_eComFormer/models/backup.py b/benchmarks/matbench_v0.1_eComFormer/models/backup.py
new file mode 100644
index 00000000..b027ca27
--- /dev/null
+++ b/benchmarks/matbench_v0.1_eComFormer/models/backup.py
@@ -0,0 +1,801 @@
+class MPNNConv(MessagePassing):
+    """Implements the message passing layer from
+    `"Crystal Graph Convolutional Neural Networks for an
+    Accurate and Interpretable Prediction of Material Properties"
+ """ + + def init(self, fc_features): + super(MPNNConv, self).init(node_dim=0) + self.bn = nn.BatchNorm1d(fc_features) + self.bn_interaction = nn.BatchNorm1d(fc_features) + self.nonlinear_full = nn.Sequential( + nn.Linear(3 * fc_features, fc_features), + nn.SiLU(), + nn.Linear(fc_features, fc_features) + ) + self.nonlinear = nn.Sequential( + nn.Linear(3 * fc_features, fc_features), + nn.SiLU(), + nn.Linear(fc_features, fc_features), + ) + + def forward(self, x, edge_index, edge_attr): + """ + Arguments: + x has shape [num_nodes, node_feat_size] + edge_index has shape [2, num_edges] + edge_attr is [num_edges, edge_feat_size] + """ + + out = self.propagate( + edge_index, x=x, edge_attr=edge_attr, size=(x.size(0), x.size(0)) + ) + + return F.relu(x + self.bn(out)) + + def message(self, x_i, x_j, edge_attr, index): + score = torch.sigmoid(self.bn_interaction(self.nonlinear_full(torch.cat((x_i, x_j, edge_attr), dim=1)))) + return score * self.nonlinear(torch.cat((x_i, x_j, edge_attr), dim=1)) + + + + +############ +# 03/08/2023 +class MatformerConv(MessagePassing): + _alpha: OptTensor + + def __init__( + self, + in_channels: Union[int, Tuple[int, int]], + out_channels: int, + heads: int = 1, + concat: bool = True, + beta: bool = False, + dropout: float = 0.0, + edge_dim: Optional[int] = None, + bias: bool = True, + root_weight: bool = True, + **kwargs, + ): + kwargs.setdefault('aggr', 'add') + super(MatformerConv, self).__init__(node_dim=0, **kwargs) + + self.in_channels = in_channels + self.out_channels = out_channels + self.heads = heads + self.beta = beta and root_weight + self.root_weight = root_weight + self.concat = concat + self.dropout = dropout + self.edge_dim = edge_dim + self._alpha = None + + if isinstance(in_channels, int): + in_channels = (in_channels, in_channels) + + self.lin_key = nn.Linear(in_channels[0], heads * out_channels) + self.lin_query = nn.Linear(in_channels[1], heads * out_channels) + self.lin_value = nn.Linear(in_channels[0], heads * out_channels) + + if edge_dim is not None: + self.lin_edge = nn.Linear(edge_dim, heads * out_channels, bias=False) + else: + self.lin_edge = self.register_parameter('lin_edge', None) + + if concat: + self.lin_skip = nn.Linear(in_channels[1], out_channels, + bias=bias) + self.lin_concate = nn.Linear(heads * out_channels, out_channels) + if self.beta: + self.lin_beta = nn.Linear(3 * heads * out_channels, 1, bias=False) + else: + self.lin_beta = self.register_parameter('lin_beta', None) + else: + self.lin_skip = nn.Linear(in_channels[1], out_channels, bias=bias) + if self.beta: + self.lin_beta = nn.Linear(3 * out_channels, 1, bias=False) + else: + self.lin_beta = self.register_parameter('lin_beta', None) + self.lin_msg_update = nn.Linear(out_channels * 3, out_channels * 3) + self.layer_norm = nn.LayerNorm(out_channels * 3) + self.msg_layer = nn.Sequential(nn.Linear(out_channels * 3, out_channels), nn.LayerNorm(out_channels)) + # simpler version + # self.lin_msg_update = nn.Linear(out_channels * 3, out_channels) + # self.layer_norm = nn.LayerNorm(out_channels) + # self.msg_layer = nn.Sequential(nn.Linear(out_channels, out_channels), nn.LayerNorm(out_channels)) + # self.msg_layer = nn.Linear(out_channels * 3, out_channels) + self.bn = nn.BatchNorm1d(out_channels) + # self.bn = nn.BatchNorm1d(out_channels * heads) + self.sigmoid = nn.Sigmoid() + self.reset_parameters() + + def reset_parameters(self): + self.lin_key.reset_parameters() + self.lin_query.reset_parameters() + self.lin_value.reset_parameters() + if self.concat: + 
self.lin_concate.reset_parameters() + if self.edge_dim: + self.lin_edge.reset_parameters() + self.lin_skip.reset_parameters() + if self.beta: + self.lin_beta.reset_parameters() + + def forward(self, x: Union[Tensor, PairTensor], edge_index: Adj, + edge_attr: OptTensor = None, return_attention_weights=None): + + H, C = self.heads, self.out_channels + if isinstance(x, Tensor): + x: PairTensor = (x, x) + + query = self.lin_query(x[1]).view(-1, H, C) + key = self.lin_key(x[0]).view(-1, H, C) + value = self.lin_value(x[0]).view(-1, H, C) + + out = self.propagate(edge_index, query=query, key=key, value=value, + edge_attr=edge_attr, size=None) + alpha = self._alpha + self._alpha = None + + if self.concat: + out = out.view(-1, self.heads * self.out_channels) + else: + out = out.mean(dim=1) + + if self.concat: + out = self.lin_concate(out) + + out = F.silu(self.bn(out)) # after norm and silu + + if self.root_weight: + x_r = self.lin_skip(x[1]) + if self.lin_beta is not None: + beta = self.lin_beta(torch.cat([out, x_r, out - x_r], dim=-1)) + beta = beta.sigmoid() + out = beta * x_r + (1 - beta) * out + else: + out += x_r + + + if isinstance(return_attention_weights, bool): + assert alpha is not None + if isinstance(edge_index, Tensor): + return out, (edge_index, alpha) + elif isinstance(edge_index, SparseTensor): + return out, edge_index.set_value(alpha, layout='coo') + else: + return out + + def message(self, query_i: Tensor, key_i: Tensor, key_j: Tensor, value_j: Tensor, value_i: Tensor, + edge_attr: OptTensor, index: Tensor, ptr: OptTensor, + size_i: Optional[int]) -> Tensor: + + if self.lin_edge is not None: + assert edge_attr is not None + edge_attr = self.lin_edge(edge_attr).view(-1, self.heads,self.out_channels) + + query_i = torch.cat((query_i, query_i, query_i), dim=-1) + key_j = torch.cat((key_i, key_j, edge_attr), dim=-1) + alpha = (query_i * key_j) / math.sqrt(self.out_channels * 3) + self._alpha = alpha + alpha = F.dropout(alpha, p=self.dropout, training=self.training) + out = torch.cat((value_i, value_j, edge_attr), dim=-1) + out = self.lin_msg_update(out) * self.sigmoid(self.layer_norm(alpha.view(-1, self.heads, 3 * self.out_channels))) + out = self.msg_layer(out) + + # version two, simpler + # query_i = query_i + # key_j = key_j + # alpha = (query_i * key_j) / math.sqrt(self.out_channels) + # self._alpha = alpha + # out = torch.cat((value_i, value_j, edge_attr), dim=-1) + # out = self.lin_msg_update(out) * self.sigmoid(self.layer_norm(alpha.view(-1, self.heads, self.out_channels))) + # out = self.msg_layer(out) + return out + + def __repr__(self) -> str: + return (f'{self.__class__.__name__}({self.in_channels}, ' + f'{self.out_channels}, heads={self.heads})') + + +class MatformerConv_edge(nn.Module): + def __init__( + self, + in_channels: Union[int, Tuple[int, int]], + out_channels: int, + heads: int = 1, + concat: bool = True, + beta: bool = False, + dropout: float = 0.0, + edge_dim: Optional[int] = None, + bias: bool = True, + root_weight: bool = True, + ): + super().__init__() + + self.in_channels = in_channels + self.out_channels = out_channels + self.heads = heads + self.beta = beta and root_weight + self.root_weight = root_weight + self.concat = concat + self.dropout = dropout + self.edge_dim = edge_dim + + if isinstance(in_channels, int): + in_channels = (in_channels, in_channels) + + self.lin_key = nn.Linear(in_channels[0], heads * out_channels) + self.lin_query = nn.Linear(in_channels[1], heads * out_channels) + self.lin_value = nn.Linear(in_channels[0], heads * 
out_channels) + + if edge_dim is not None: + self.lin_edge = nn.Linear(edge_dim, heads * out_channels, bias=False) + else: + self.lin_edge = self.register_parameter('lin_edge', None) + + if concat: + self.lin_skip = nn.Linear(in_channels[1], out_channels, + bias=bias) + self.lin_concate = nn.Linear(heads * out_channels, out_channels) + if self.beta: + self.lin_beta = nn.Linear(3 * heads * out_channels, 1, bias=False) + else: + self.lin_beta = self.register_parameter('lin_beta', None) + else: + self.lin_skip = nn.Linear(in_channels[1], out_channels, bias=bias) + if self.beta: + self.lin_beta = nn.Linear(3 * out_channels, 1, bias=False) + else: + self.lin_beta = self.register_parameter('lin_beta', None) + self.lin_msg_update = nn.Linear(out_channels * 3, out_channels * 3) + self.layer_norm = nn.LayerNorm(out_channels * 3) + self.msg_layer = nn.Sequential(nn.Linear(out_channels * 3, out_channels), nn.LayerNorm(out_channels)) + self.bn = nn.BatchNorm1d(out_channels) + self.sigmoid = nn.Sigmoid() + self.reset_parameters() + + def reset_parameters(self): + self.lin_key.reset_parameters() + self.lin_query.reset_parameters() + self.lin_value.reset_parameters() + if self.concat: + self.lin_concate.reset_parameters() + if self.edge_dim: + self.lin_edge.reset_parameters() + self.lin_skip.reset_parameters() + if self.beta: + self.lin_beta.reset_parameters() + + def forward(self, edge: Union[Tensor, PairTensor], edge_nei_len: OptTensor = None, edge_nei_angle: OptTensor = None): + # preprocess for edge of shape [num_edges, hidden_dim] + + H, C = self.heads, self.out_channels + if isinstance(edge, Tensor): + edge: PairTensor = (edge, edge) + + query_x = self.lin_query(edge[1]).view(-1, H, C).unsqueeze(1).repeat(1, 3, 1, 1) + key_x = self.lin_key(edge[0]).view(-1, H, C).unsqueeze(1).repeat(1, 3, 1, 1) + value_x = self.lin_value(edge[0]).view(-1, H, C).unsqueeze(1).repeat(1, 3, 1, 1) + + # preprocess for edge_neighbor of shape [num_edges, 3, hidden_dim] + query_y = self.lin_query(edge_nei_len).view(-1, 3, H, C) + key_y = self.lin_key(edge_nei_len).view(-1, 3, H, C) + value_y = self.lin_value(edge_nei_len).view(-1, 3, H, C) + + # preprocess for interaction of shape [num_edges, 3, hidden_dim] + edge_xy = self.lin_edge(edge_nei_angle).view(-1, 3, H, C) + + query = torch.cat((query_x, query_x, query_x), dim=-1) + key = torch.cat((key_x, key_y, edge_xy), dim=-1) + alpha = (query * key) / math.sqrt(self.out_channels * 3) + out = torch.cat((value_x, value_y, edge_xy), dim=-1) + out = self.lin_msg_update(out) * self.sigmoid(self.layer_norm(alpha)) + out = self.msg_layer(out) + + if self.concat: + out = out.view(-1, 3, self.heads * self.out_channels) + else: + out = out.mean(dim=2) + + if self.concat: + out = self.lin_concate(out) + + # aggregate the msg + out = out.sum(dim=1) + + out = F.silu(self.bn(out)) + + if self.root_weight: + x_r = self.lin_skip(edge[1]) + out += x_r + + return out + + +##################### +# 03/07/2023 +##################### + + +# class MatformerConv_edge(MessagePassing): +# _alpha: OptTensor + +# def __init__( +# self, +# in_channels: Union[int, Tuple[int, int]], +# out_channels: int, +# heads: int = 1, +# concat: bool = True, +# beta: bool = False, +# dropout: float = 0.0, +# edge_dim: Optional[int] = None, +# bias: bool = True, +# root_weight: bool = True, +# **kwargs, +# ): +# kwargs.setdefault('aggr', 'add') +# super(MatformerConv_edge, self).__init__(node_dim=0, **kwargs) + +# self.in_channels = in_channels +# self.out_channels = out_channels +# self.heads = heads +# self.beta = 
beta and root_weight +# self.root_weight = root_weight +# self.concat = concat +# self.dropout = dropout +# self.edge_dim = edge_dim +# self._alpha = None + +# if isinstance(in_channels, int): +# in_channels = (in_channels, in_channels) + +# self.lin_key = nn.Linear(in_channels[0], heads * out_channels) +# self.lin_query = nn.Linear(in_channels[1], heads * out_channels) +# self.lin_value = nn.Linear(in_channels[0], heads * out_channels) + +# if edge_dim is not None: +# self.lin_edge = nn.Linear(edge_dim, heads * out_channels, bias=False) +# else: +# self.lin_edge = self.register_parameter('lin_edge', None) + +# if concat: +# self.lin_skip = nn.Linear(in_channels[1], out_channels, +# bias=bias) +# self.lin_concate = nn.Linear(heads * out_channels, out_channels) +# if self.beta: +# self.lin_beta = nn.Linear(3 * heads * out_channels, 1, bias=False) +# else: +# self.lin_beta = self.register_parameter('lin_beta', None) +# else: +# self.lin_skip = nn.Linear(in_channels[1], out_channels, bias=bias) +# if self.beta: +# self.lin_beta = nn.Linear(3 * out_channels, 1, bias=False) +# else: +# self.lin_beta = self.register_parameter('lin_beta', None) +# self.lin_msg_update = nn.Linear(out_channels * 3, out_channels * 3) +# self.layer_norm = nn.LayerNorm(out_channels * 3) +# self.msg_layer = nn.Sequential(nn.Linear(out_channels * 3, out_channels), nn.LayerNorm(out_channels)) +# # simpler version +# # self.lin_msg_update = nn.Linear(out_channels * 3, out_channels) +# # self.layer_norm = nn.LayerNorm(out_channels) +# # self.msg_layer = nn.Sequential(nn.Linear(out_channels, out_channels), nn.LayerNorm(out_channels)) +# # self.msg_layer = nn.Linear(out_channels * 3, out_channels) +# self.bn = nn.BatchNorm1d(out_channels) +# # self.bn = nn.BatchNorm1d(out_channels * heads) +# self.sigmoid = nn.Sigmoid() +# self.reset_parameters() + +# def reset_parameters(self): +# self.lin_key.reset_parameters() +# self.lin_query.reset_parameters() +# self.lin_value.reset_parameters() +# if self.concat: +# self.lin_concate.reset_parameters() +# if self.edge_dim: +# self.lin_edge.reset_parameters() +# self.lin_skip.reset_parameters() +# if self.beta: +# self.lin_beta.reset_parameters() + +# def forward(self, x: Union[Tensor, PairTensor], edge_index: Adj, +# edge_attr: OptTensor = None, return_attention_weights=None): + +# H, C = self.heads, self.out_channels +# if isinstance(x, Tensor): +# x: PairTensor = (x, x) + +# query = self.lin_query(x[1]).view(-1, H, C) +# key = self.lin_key(x[0]).view(-1, H, C) +# value = self.lin_value(x[0]).view(-1, H, C) + +# out = self.propagate(edge_index, query=query, key=key, value=value, +# edge_attr=edge_attr, size=None) +# alpha = self._alpha +# self._alpha = None + +# if self.concat: +# out = out.view(-1, self.heads * self.out_channels) +# else: +# out = out.mean(dim=1) + +# if self.concat: +# out = self.lin_concate(out) + +# out = F.silu(self.bn(out)) # after norm and silu + +# if self.root_weight: +# x_r = self.lin_skip(x[1]) +# if self.lin_beta is not None: +# beta = self.lin_beta(torch.cat([out, x_r, out - x_r], dim=-1)) +# beta = beta.sigmoid() +# out = beta * x_r + (1 - beta) * out +# else: +# out += x_r + + +# if isinstance(return_attention_weights, bool): +# assert alpha is not None +# if isinstance(edge_index, Tensor): +# return out, (edge_index, alpha) +# elif isinstance(edge_index, SparseTensor): +# return out, edge_index.set_value(alpha, layout='coo') +# else: +# return out + +# def message(self, query_i: Tensor, key_i: Tensor, key_j: Tensor, value_j: Tensor, value_i: Tensor, 
+# edge_attr: OptTensor, index: Tensor, ptr: OptTensor, +# size_i: Optional[int]) -> Tensor: + +# if self.lin_edge is not None: +# assert edge_attr is not None +# edge_attr = self.lin_edge(edge_attr).view(-1, self.heads,self.out_channels) + +# query_i = torch.cat((query_i, query_i, query_i), dim=-1) +# key_j = torch.cat((key_i, key_j, edge_attr), dim=-1) +# alpha = (query_i * key_j) / math.sqrt(self.out_channels * 3) +# self._alpha = alpha +# alpha = F.dropout(alpha, p=self.dropout, training=self.training) +# out = torch.cat((value_i, value_j, edge_attr), dim=-1) +# out = self.lin_msg_update(out) * self.sigmoid(self.layer_norm(alpha.view(-1, self.heads, 3 * self.out_channels))) +# out = self.msg_layer(out) + +# # version two, simpler +# # query_i = query_i +# # key_j = key_j +# # alpha = (query_i * key_j) / math.sqrt(self.out_channels) +# # self._alpha = alpha +# # out = torch.cat((value_i, value_j, edge_attr), dim=-1) +# # out = self.lin_msg_update(out) * self.sigmoid(self.layer_norm(alpha.view(-1, self.heads, self.out_channels))) +# # out = self.msg_layer(out) +# return out + +# def __repr__(self) -> str: +# return (f'{self.__class__.__name__}({self.in_channels}, ' +# f'{self.out_channels}, heads={self.heads})') + + + +##################### +# 03/21/2023 +##################### + +class MatformerConvEqui(nn.Module): + def __init__( + self, + in_channels: Union[int, Tuple[int, int]], + out_channels: int, + edge_dim: Optional[int] = None, + use_second_order_repr: bool = True, + ns: int = 64, + nv: int = 8, + residual: bool = True, + ): + super().__init__() + + if use_second_order_repr: + irrep_seq = [ + f'{ns}x0e', + f'{ns}x0e + {nv}x1o + {nv}x2e', + f'{out_channels}x0e' + ] + else: + irrep_seq = [ + f'{ns}x0e', + f'{ns}x0e + {nv}x1o', + f'{out_channels}x0e', + ] + self.ns = ns + self.nv = nv + self.node_linear = nn.Linear(in_channels, ns) + self.skip_linear = nn.Linear(in_channels, out_channels) + self.v1_v2_linear = nn.Linear(ns, out_channels) + + self.sh = '1x0e + 1x1o + 1x2e' + self.v2_tp = v2_tp = o3.FullyConnectedTensorProduct(f'{ns}x0e + {nv}x1o + {nv}x2e', '1x0e + 1x1o + 1x2e', f'{out_channels}x0e', shared_weights=False) + self.v2_fc = nn.Sequential( + nn.Linear(edge_dim * 3, edge_dim), + nn.Softplus(), + nn.Linear(edge_dim, v2_tp.weight_numel) + ) + self.nlayer_1 = TensorProductConvLayer( + in_irreps=irrep_seq[0], + sh_irreps=self.sh, + out_irreps=irrep_seq[1], + n_edge_features=edge_dim, + residual=residual + ) + + self.nlayer_2 = TensorProductConvLayer( + in_irreps=irrep_seq[1], + sh_irreps=self.sh, + out_irreps=irrep_seq[1], + n_edge_features=edge_dim, + residual=residual + ) + # MACE + self.softplus = nn.Softplus() + self.ln_0e = nn.Parameter(torch.ones(1, 3, 1)) + self.ln_1o = nn.Parameter(torch.ones(1, 3, 1)) + self.ln_2e = nn.Parameter(torch.ones(1, 3, 1)) + self.bn = nn.BatchNorm1d(out_channels) + + def forward(self, data, node_feature: Union[Tensor, PairTensor], edge_index: Adj, edge_feature: Union[Tensor, PairTensor], + lat_len: OptTensor = None): + edge_vec = data.edge_attr + n_ = node_feature.shape[0] + edge_irr = o3.spherical_harmonics(self.sh, edge_vec, normalize=True, normalization='component') + lat = o3.spherical_harmonics(self.sh, data.atom_lat.view(n_ * 3, 3), normalize=True, normalization='component') + skip_connect = node_feature + node_feature = self.node_linear(node_feature) + irreps = o3.Irreps('1x0e + 1x1o + 1x2e') + node_feature = self.nlayer_1(node_feature, edge_index, edge_feature, edge_irr) + # the second layer + lat_l0, lat_l1o, lat_l2e = lat[:, 
irreps.slices()[0]], lat[:, irreps.slices()[1]], lat[:, irreps.slices()[2]] + lat_l0 = (lat_l0.view(n_, 3, 1) * self.ln_0e).sum(dim=1) + lat_l1o = (lat_l1o.view(n_, 3, 3) * self.ln_1o).sum(dim=1) + lat_l2e = (lat_l2e.view(n_, 3, 5) * self.ln_2e).sum(dim=1) + lat_vec = torch.cat((lat_l0, lat_l1o, lat_l2e), dim=-1) + node_v2 = self.v2_tp(node_feature, lat_vec, self.v2_fc(lat_len.view(n_, -1))) + node_v2 = self.softplus(self.bn(node_v2)) + node_v2 += self.skip_linear(skip_connect) + + return node_v2 + + + +class MatformerConvEqui(nn.Module): + def __init__( + self, + in_channels: Union[int, Tuple[int, int]], + out_channels: int, + edge_dim: Optional[int] = None, + use_second_order_repr: bool = True, + ns: int = 128, + nv: int = 8, + residual: bool = True, + ): + super().__init__() + + if use_second_order_repr: + irrep_seq = [ + f'{ns}x0e', + f'{ns}x0e + {nv}x1o + {nv}x2e', + f'{out_channels}x0e' + ] + else: + irrep_seq = [ + f'{ns}x0e', + f'{ns}x0e + {nv}x1o', + f'{out_channels}x0e', + ] + self.ns = ns + self.nv = nv + # for input x mapping + self.node_linear = nn.Linear(in_channels, ns) + # for input x mapping to the output + self.skip_linear = nn.Linear(in_channels, out_channels) + # for l0 mapping to the output + self.v1_v2_linear = nn.Linear(ns, out_channels) + + self.sh = '1x0e + 1x1o + 1x2e' + self.v2_tp = v2_tp = o3.FullyConnectedTensorProduct(f'{ns}x0e + {nv}x1o + {nv}x2e', '1x0e + 1x1o + 1x2e', f'{ns}x0e + {nv}x1o + {nv}x2e', shared_weights=False) + self.v2_fc = nn.Sequential( + nn.Linear(ns, ns), + nn.Softplus(), + nn.Linear(ns, v2_tp.weight_numel) + ) + + self.v2_tp_2 = v2_tp_2 = o3.FullyConnectedTensorProduct(f'{ns}x0e + {nv}x1o + {nv}x2e', '1x0e + 1x1o + 1x2e', f'{out_channels}x0e', shared_weights=False) + self.v2_fc_2 = nn.Sequential( + nn.Linear(ns, ns), + nn.Softplus(), + nn.Linear(ns, v2_tp_2.weight_numel) + ) + + self.nlayer_1 = TensorProductConvLayer( + in_irreps=irrep_seq[0], + sh_irreps=self.sh, + out_irreps=irrep_seq[1], + n_edge_features=edge_dim, + residual=residual + ) + + self.nlayer_2 = TensorProductConvLayer( + in_irreps=irrep_seq[1], + sh_irreps=self.sh, + out_irreps=irrep_seq[1], + n_edge_features=edge_dim, + residual=residual + ) + + # MACE + self.softplus = nn.Softplus() + self.ln_0e = nn.Parameter(torch.ones(1, ns)) + self.ln_1o = nn.Parameter(torch.ones(1, nv, 1)) + self.ln_2e = nn.Parameter(torch.ones(1, nv, 1)) + self.bn = nn.BatchNorm1d(ns) + + self.ln_0e2 = nn.Parameter(torch.ones(1, ns)) + self.ln_1o2 = nn.Parameter(torch.ones(1, nv, 1)) + self.ln_2e2 = nn.Parameter(torch.ones(1, nv, 1)) + self.bn2 = nn.BatchNorm1d(out_channels) + + def forward(self, data, node_feature: Union[Tensor, PairTensor], edge_index: Adj, edge_feature: Union[Tensor, PairTensor], + edge_nei_len: OptTensor = None): + edge_vec = data.edge_attr + edge_irr = o3.spherical_harmonics(self.sh, edge_vec, normalize=True, normalization='component') + n_ = node_feature.shape[0] + skip_connect = node_feature + node_feature = self.node_linear(node_feature) + ns, nv = self.ns, self.nv + irreps = o3.Irreps(f'{ns}x0e + {nv}x1o + {nv}x2e') + # the first layer + node_feature = self.nlayer_1(node_feature, edge_index, edge_feature, edge_irr) + # the second layer + node_l0, node_l1o, node_l2e = node_feature[:, irreps.slices()[0]], node_feature[:, irreps.slices()[1]], node_feature[:, irreps.slices()[2]] + node_l1o, node_l2e = node_l1o.reshape(n_, -1, 3), node_l2e.reshape(n_, -1, 5) + node_l0 = self.softplus(node_l0) + node_l0_update = (node_l0 * self.ln_0e).sum(dim=1, keepdim=True) + node_l1o = 
(node_l1o * self.ln_1o).sum(dim=1, keepdim=True) + node_l2e = (node_l2e * self.ln_2e).sum(dim=1, keepdim=True) + node_feature_vec = torch.cat((node_l0_update, node_l1o.reshape(n_, -1), node_l2e.reshape(n_, -1)), dim=-1) + node_v2 = self.v2_tp(node_feature, node_feature_vec, self.v2_fc(node_l0)) + node_v2_l0 = node_v2[:, irreps.slices()[0]] + node_v2_l0 = node_v2_l0 + node_l0 + node_v2_l0 = self.softplus(self.bn(node_v2_l0)) + node_v2[:, irreps.slices()[0]] = node_v2_l0 + # the first layer + node_feature = self.nlayer_2(node_v2, edge_index, edge_feature, edge_irr) + # the second layer + node_l0, node_l1o, node_l2e = node_feature[:, irreps.slices()[0]], node_feature[:, irreps.slices()[1]], node_feature[:, irreps.slices()[2]] + node_l1o, node_l2e = node_l1o.reshape(n_, -1, 3), node_l2e.reshape(n_, -1, 5) + node_l0 = self.softplus(node_l0) + node_l0_update = (node_l0 * self.ln_0e2).sum(dim=1, keepdim=True) + node_l1o = (node_l1o * self.ln_1o2).sum(dim=1, keepdim=True) + node_l2e = (node_l2e * self.ln_2e2).sum(dim=1, keepdim=True) + node_feature_vec = torch.cat((node_l0_update, node_l1o.reshape(n_, -1), node_l2e.reshape(n_, -1)), dim=-1) + node_v2 = self.v2_tp_2(node_feature, node_feature_vec, self.v2_fc_2(node_l0)) + node_v2 = node_v2 + self.v1_v2_linear(node_l0) + node_v2 = self.softplus(self.bn2(node_v2)) + + node_v2 += self.skip_linear(skip_connect) + return node_v2 + + + + + # edge_nei_vec = data.edge_nei / data.edge_nei.norm(dim=-1, keepdim=True) + # edge_irr = torch.cat((self.edge_tp(edge_vec.unsqueeze(1).repeat(1, 3, 1), edge_nei_vec, self.edge_tp_fc(edge_nei_len)).sum(dim=1), + # edge_vec), dim=-1) + +# nonlinearity and norm of equi features + # ns, nv = self.ns, self.nv + # irreps = o3.Irreps(f'{ns}x0e + {nv}x1o + {nv}x2e') + # node_l0, node_l1o, node_l1e = node_feature[:, irreps.slices()[0]], node_feature[:, irreps.slices()[1]], node_feature[:, irreps.slices()[2]] + # node_l1o, node_l1e = node_l1o.reshape(n_, -1, 3), node_l1e.reshape(n_, -1, 3) + # # for order = 0 part + # node_l0 = self.softplus(node_l0) + # rms_l0 = node_l0.norm(dim=-1, keepdim=True) * (ns ** -0.5) + # node_l0 = node_l0 / rms_l0.clamp(min = 1e-12) * self.ln_0e + # # for order = 1o part + # l2norm = node_l1o.norm(dim=-1, keepdim=True) + # rms_l1o = l2norm.norm(dim=-2, keepdim=True) * (nv ** -0.5) + # node_l1o = node_l1o / rms_l1o.clamp(min = 1e-12) * self.ln_1o + # # for order = 1e part + # l2norme = node_l1e.norm(dim=-1, keepdim=True) + # rms_l1e = l2norme.norm(dim=-2, keepdim=True) * (nv ** -0.5) + # node_l1e = node_l1e / rms_l1e.clamp(min = 1e-12) * self.ln_1e + # node_feature = torch.cat((node_l0, node_l1o.reshape(n_, -1), node_l1e.reshape(n_, -1)), dim=-1) + # the second layer + # node_feature = self.nlayer_2(node_feature, edge_index, edge_feature, edge_irr) + + + +class MatformerConvEqui(nn.Module): + def __init__( + self, + in_channels: Union[int, Tuple[int, int]], + out_channels: int, + edge_dim: Optional[int] = None, + use_second_order_repr: bool = True, + ns: int = 128, + nv: int = 8, + residual: bool = True, + ): + super().__init__() + + # if use_second_order_repr: + irrep_seq = [ + f'{ns}x0e', + f'{ns}x0e + {nv}x1o + {nv}x2e', + f'{out_channels}x0e' + ] + # else: + # irrep_seq = [ + # f'{ns}x0e', + # f'{ns}x0e + {nv}x1o', + # f'{out_channels}x0e', + # ] + + self.node_linear = nn.Linear(in_channels, ns) + # for input x mapping to the output + self.skip_linear = nn.Linear(in_channels, out_channels) + + self.sh = '1x0e + 1x1o + 1x2e' + + self.nlayer_1 = TensorProductConvLayer( + in_irreps=irrep_seq[0], + 
sh_irreps=self.sh, + out_irreps=irrep_seq[1], + n_edge_features=edge_dim, + residual=residual + ) + + self.nlayer_2 = TensorProductConvLayer( + in_irreps=irrep_seq[1], + sh_irreps=self.sh, + out_irreps=irrep_seq[2], + n_edge_features=edge_dim, + residual=False + ) + + self.softplus = nn.Softplus() + self.bn = nn.BatchNorm1d(out_channels) + + def forward(self, data, node_feature: Union[Tensor, PairTensor], edge_index: Adj, edge_feature: Union[Tensor, PairTensor], + edge_nei_len: OptTensor = None): + edge_vec = data.edge_attr + edge_irr = o3.spherical_harmonics(self.sh, edge_vec, normalize=True, normalization='component') + n_ = node_feature.shape[0] + skip_connect = node_feature + node_feature = self.node_linear(node_feature) + node_feature = self.nlayer_1(node_feature, edge_index, edge_feature, edge_irr) + node_feature = self.nlayer_2(node_feature, edge_index, edge_feature, edge_irr) + node_feature = self.softplus(self.bn(node_feature)) + node_feature += self.skip_linear(skip_connect) + + return node_feature + + + + + # ns, nv = self.ns, self.nv + # irreps = o3.Irreps(f'{ns}x0e + {nv}x1o + {nv}x2e') + # node_l0, node_l1o, node_l2e = node_feature[:, irreps.slices()[0]], node_feature[:, irreps.slices()[1]], node_feature[:, irreps.slices()[2]] + # node_l1o, node_l2e = node_l1o.reshape(n_, -1, 3), node_l2e.reshape(n_, -1, 5) + # # for order = 0 part + # node_l0 = self.softplus(node_l0) + # rms_l0 = node_l0.norm(dim=-1, keepdim=True) * (ns ** -0.5) + # node_l0 = node_l0 / rms_l0.clamp(min = 1e-12) * self.ln_0e + # # for order = 1o part + # l2norm = node_l1o.norm(dim=-1, keepdim=True) + # rms_l1o = l2norm.norm(dim=-2, keepdim=True) * (nv ** -0.5) + # node_l1o = node_l1o / rms_l1o.clamp(min = 1e-12) * self.ln_1o + # # for order = 1e part + # l2norme = node_l2e.norm(dim=-1, keepdim=True) + # rms_l2e = l2norme.norm(dim=-2, keepdim=True) * (nv ** -0.5) + # node_l2e = node_l2e / rms_l2e.clamp(min = 1e-12) * self.ln_2e + # node_feature = torch.cat((node_l0, node_l1o.reshape(n_, -1), node_l2e.reshape(n_, -1)), dim=-1) \ No newline at end of file diff --git a/benchmarks/matbench_v0.1_eComFormer/models/bn_utils.py b/benchmarks/matbench_v0.1_eComFormer/models/bn_utils.py new file mode 100644 index 00000000..eec8f27e --- /dev/null +++ b/benchmarks/matbench_v0.1_eComFormer/models/bn_utils.py @@ -0,0 +1,269 @@ +from typing import Optional, Any + +import torch +from torch import Tensor +from torch.nn.parameter import Parameter, UninitializedParameter, UninitializedBuffer + +from torch.nn import functional as F +from torch.nn import init +from torch.nn.modules._functions import SyncBatchNorm as sync_batch_norm +from torch.nn.modules.lazy import LazyModuleMixin +from torch.nn.modules.module import Module + + +class _NormBase(Module): + """Common base of _InstanceNorm and _BatchNorm""" + + _version = 2 + __constants__ = ["track_running_stats", "momentum", "eps", "num_features", "affine"] + num_features: int + eps: float + momentum: float + affine: bool + track_running_stats: bool + # WARNING: weight and bias purposely not defined here. 
+ # See https://github.com/pytorch/pytorch/issues/39670 + + def __init__( + self, + num_features: int, + eps: float = 1e-5, + momentum: float = 0.1, + affine: bool = True, + track_running_stats: bool = True, + device=None, + dtype=None + ) -> None: + factory_kwargs = {'device': device, 'dtype': dtype} + super(_NormBase, self).__init__() + self.num_features = num_features + self.eps = eps + self.momentum = momentum + self.affine = affine + self.track_running_stats = track_running_stats + if self.affine: + self.weight = Parameter(torch.empty(num_features, **factory_kwargs)) + self.bias = Parameter(torch.empty(num_features, **factory_kwargs)) + else: + self.register_parameter("weight", None) + self.register_parameter("bias", None) + if self.track_running_stats: + self.register_buffer('running_mean', torch.zeros(num_features, **factory_kwargs)) + self.register_buffer('running_var', torch.ones(num_features, **factory_kwargs)) + self.running_mean: Optional[Tensor] + self.running_var: Optional[Tensor] + self.register_buffer('num_batches_tracked', + torch.tensor(0, dtype=torch.long, + **{k: v for k, v in factory_kwargs.items() if k != 'dtype'})) + self.num_batches_tracked: Optional[Tensor] + else: + self.register_buffer("running_mean", None) + self.register_buffer("running_var", None) + self.register_buffer("num_batches_tracked", None) + self.reset_parameters() + + def reset_running_stats(self) -> None: + if self.track_running_stats: + # running_mean/running_var/num_batches... are registered at runtime depending + # if self.track_running_stats is on + self.running_mean.zero_() # type: ignore[union-attr] + self.running_var.fill_(1) # type: ignore[union-attr] + self.num_batches_tracked.zero_() # type: ignore[union-attr,operator] + + def reset_parameters(self) -> None: + self.reset_running_stats() + if self.affine: + init.ones_(self.weight) + init.zeros_(self.bias) + + def _check_input_dim(self, input): + raise NotImplementedError + + def extra_repr(self): + return ( + "{num_features}, eps={eps}, momentum={momentum}, affine={affine}, " + "track_running_stats={track_running_stats}".format(**self.__dict__) + ) + + def _load_from_state_dict( + self, + state_dict, + prefix, + local_metadata, + strict, + missing_keys, + unexpected_keys, + error_msgs, + ): + version = local_metadata.get("version", None) + + if (version is None or version < 2) and self.track_running_stats: + # at version 2: added num_batches_tracked buffer + # this should have a default value of 0 + num_batches_tracked_key = prefix + "num_batches_tracked" + if num_batches_tracked_key not in state_dict: + state_dict[num_batches_tracked_key] = torch.tensor(0, dtype=torch.long) + + super(_NormBase, self)._load_from_state_dict( + state_dict, + prefix, + local_metadata, + strict, + missing_keys, + unexpected_keys, + error_msgs, + ) + + +class _MaskedBatchNorm(_NormBase): + def __init__( + self, + num_features, + eps=1e-5, + momentum=0.1, + affine=True, + track_running_stats=True, + device=None, + dtype=None + ): + factory_kwargs = {'device': device, 'dtype': dtype} + super(_MaskedBatchNorm, self).__init__( + num_features, eps, momentum, affine, track_running_stats, **factory_kwargs + ) + + def forward(self, input: Tensor, mask: Tensor) -> Tensor: + self._check_input_dim(input) + # exponential_average_factor is set to self.momentum + # (when it is available) only so that it gets updated + # in ONNX graph when this node is exported to ONNX. 
+ if self.momentum is None: + exponential_average_factor = 0.0 + else: + exponential_average_factor = self.momentum + + if self.training and self.track_running_stats: + # TODO: if statement only here to tell the jit to skip emitting this when it is None + if self.num_batches_tracked is not None: # type: ignore[has-type] + self.num_batches_tracked.add_(1) # type: ignore[has-type] + if self.momentum is None: # use cumulative moving average + exponential_average_factor = 1.0 / float(self.num_batches_tracked) + else: # use exponential moving average + exponential_average_factor = self.momentum + + r""" + Decide whether the mini-batch stats should be used for normalization rather than the buffers. + Mini-batch stats are used in training mode, and in eval mode when buffers are None. + """ + if self.training: + bn_training = True + else: + bn_training = (self.running_mean is None) and (self.running_var is None) + + r""" + Buffers are only updated if they are to be tracked and we are in training mode. Thus they only need to be + passed when the update should occur (i.e. in training mode when they are tracked), or when buffer stats are + used for normalization (i.e. in eval mode when buffers are not None). + """ + result, self.running_mean, self.running_var = batch_norm( + X=input, + running_mean=self.running_mean + if not self.training or self.track_running_stats + else None, + running_var=self.running_var if not self.training or self.track_running_stats else None, + weight=self.weight, + bias=self.bias, + training=bn_training, + momentum=exponential_average_factor, + eps=self.eps, + mask=mask, + ) + return result + +def batch_norm(X, weight, bias, running_mean, running_var, training, momentum, eps, mask): + if not training: + X_hat = (X - running_mean) / torch.sqrt(running_var + eps) + else: + count = mask.sum(dim=0, keepdim=True) + mean = (X * mask).sum(dim=0, keepdim=True) / (count + 1e-5) + var = (((X - mean) ** 2) * mask).sum(dim=0, keepdim=True) / (count + 1e-5) + X_hat = (X - mean) / torch.sqrt(var + eps) + # Update the mean and variance using moving average + running_mean = momentum * running_mean + (1.0 - momentum) * mean + running_var = momentum * running_var + (1.0 - momentum) * var + Y = (weight * X_hat + bias) * mask # Scale and shift + return Y, running_mean.data, running_var.data + + + +class MaskedBatchNorm1d(_MaskedBatchNorm): + r"""Applies Batch Normalization over a 2D or 3D input as described in the paper + `Batch Normalization: Accelerating Deep Network Training by Reducing + Internal Covariate Shift `__ . + + .. math:: + + y = \frac{x - \mathrm{E}[x]}{\sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta + + The mean and standard-deviation are calculated per-dimension over + the mini-batches and :math:`\gamma` and :math:`\beta` are learnable parameter vectors + of size `C` (where `C` is the number of features or channels of the input). By default, the + elements of :math:`\gamma` are set to 1 and the elements of :math:`\beta` are set to 0. The + standard-deviation is calculated via the biased estimator, equivalent to `torch.var(input, unbiased=False)`. + + Also by default, during training this layer keeps running estimates of its + computed mean and variance, which are then used for normalization during + evaluation. The running estimates are kept with a default :attr:`momentum` + of 0.1. + + If :attr:`track_running_stats` is set to ``False``, this layer then does not + keep running estimates, and batch statistics are instead used during + evaluation time as well. + + .. 
note:: + This :attr:`momentum` argument is different from one used in optimizer + classes and the conventional notion of momentum. Mathematically, the + update rule for running statistics here is + :math:`\hat{x}_\text{new} = (1 - \text{momentum}) \times \hat{x} + \text{momentum} \times x_t`, + where :math:`\hat{x}` is the estimated statistic and :math:`x_t` is the + new observed value. + + Because the Batch Normalization is done over the `C` dimension, computing statistics + on `(N, L)` slices, it's common terminology to call this Temporal Batch Normalization. + + Args: + num_features: number of features or channels :math:`C` of the input + eps: a value added to the denominator for numerical stability. + Default: 1e-5 + momentum: the value used for the running_mean and running_var + computation. Can be set to ``None`` for cumulative moving average + (i.e. simple average). Default: 0.1 + affine: a boolean value that when set to ``True``, this module has + learnable affine parameters. Default: ``True`` + track_running_stats: a boolean value that when set to ``True``, this + module tracks the running mean and variance, and when set to ``False``, + this module does not track such statistics, and initializes statistics + buffers :attr:`running_mean` and :attr:`running_var` as ``None``. + When these buffers are ``None``, this module always uses batch statistics. + in both training and eval modes. Default: ``True`` + + Shape: + - Input: :math:`(N, C)` or :math:`(N, C, L)`, where :math:`N` is the batch size, + :math:`C` is the number of features or channels, and :math:`L` is the sequence length + - Output: :math:`(N, C)` or :math:`(N, C, L)` (same shape as input) + + Examples:: + + >>> # With Learnable Parameters + >>> m = nn.BatchNorm1d(100) + >>> # Without Learnable Parameters + >>> m = nn.BatchNorm1d(100, affine=False) + >>> input = torch.randn(20, 100) + >>> output = m(input) + """ + + def _check_input_dim(self, input): + if input.dim() != 2 and input.dim() != 3: + raise ValueError( + "expected 2D or 3D input (got {}D input)".format(input.dim()) + ) + diff --git a/benchmarks/matbench_v0.1_eComFormer/models/pyg_att.py b/benchmarks/matbench_v0.1_eComFormer/models/pyg_att.py new file mode 100644 index 00000000..0ac69300 --- /dev/null +++ b/benchmarks/matbench_v0.1_eComFormer/models/pyg_att.py @@ -0,0 +1,263 @@ +"""Implementation based on the template of ALIGNN.""" + +from typing import Tuple +import math +import numpy as np +import torch +import torch.nn.functional as F +from pydantic.typing import Literal +from torch import nn +from matformer.models.utils import RBFExpansion +from matformer.utils import BaseSettings +from matformer.features import angle_emb_mp +from torch_scatter import scatter +from matformer.models.transformer import MatformerConv, MatformerConv_edge, MatformerConvEqui + + +class MatformerConfig(BaseSettings): + """Hyperparameter schema for jarvisdgl.models.cgcnn.""" + + name: Literal["matformer"] + conv_layers: int = 3 + edge_layers: int = 1 + atom_input_features: int = 92 + edge_features: int = 256 + triplet_input_features: int = 256 + node_features: int = 256 + fc_layers: int = 1 + fc_features: int = 256 + output_features: int = 1 + node_layer_head: int = 1 + edge_layer_head: int = 1 + nn_based: bool = False + + link: Literal["identity", "log", "logit"] = "identity" + zero_inflated: bool = False + use_angle: bool = False + angle_lattice: bool = False + classification: bool = False + + class Config: + """Configure model settings behavior.""" + + env_prefix = "jv_model" + + 
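# A minimal usage sketch for the schema above, assuming matformer.utils.BaseSettings
# keeps pydantic's environment handling (so fields can be set in code or overridden
# via variables prefixed with "jv_model"); the field names below follow the schema
# declared in MatformerConfig, and MatformerEquivariant is the eComFormer model
# defined later in this file:
#
#     cfg = MatformerConfig(
#         name="matformer",        # required Literal field
#         conv_layers=3,           # number of node attention (MatformerConv) layers
#         node_features=256,       # hidden dimension used for node and edge embeddings
#         output_features=1,       # single scalar regression target
#         classification=False,
#     )
#     model = MatformerEquivariant(cfg)
#
# Whether environment overrides such as JV_MODEL_NODE_FEATURES are actually honoured
# depends on the BaseSettings wrapper in matformer.utils, so that part is an
# assumption rather than a documented feature of this benchmark code.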
+def bond_cosine(r1, r2): + bond_cosine = torch.sum(r1 * r2, dim=-1) / ( + torch.norm(r1, dim=-1) * torch.norm(r2, dim=-1) + ) + bond_cosine = torch.clamp(bond_cosine, -1, 1) + return bond_cosine + +class MatformerEquivariant(nn.Module): + """att pyg implementation.""" + + def __init__(self, config: MatformerConfig = MatformerConfig(name="matformer")): + """Set up att modules.""" + super().__init__() + print("Using equivariant marformer !!!!!!!!!!!!!!!!!!!!!!!!") + self.classification = config.classification + self.use_angle = config.use_angle + self.atom_embedding = nn.Linear( + config.atom_input_features, config.node_features + ) + self.rbf = nn.Sequential( + RBFExpansion( + vmin=-4.0, + vmax=0.0, + bins=config.edge_features, + ), + nn.Linear(config.edge_features, config.node_features), + nn.Softplus(), + # nn.Linear(config.node_features, config.node_features), + ) + + self.rbf_angle = nn.Sequential( + RBFExpansion( + vmin=-1.0, + vmax=1.0, + bins=config.triplet_input_features, + ), + nn.Linear(config.triplet_input_features, config.node_features), + nn.Softplus(), + # nn.Linear(config.node_features, config.node_features), + ) + + self.att_layers = nn.ModuleList( + [ + MatformerConv(in_channels=config.node_features, out_channels=config.node_features, heads=config.node_layer_head, edge_dim=config.node_features) + for _ in range(config.conv_layers) + ] + ) + + self.edge_update_layer = MatformerConv_edge(in_channels=config.node_features, out_channels=config.node_features, heads=config.node_layer_head, edge_dim=config.node_features) + + self.equi_update = MatformerConvEqui(in_channels=config.node_features, out_channels=config.node_features, edge_dim=config.node_features, use_second_order_repr=True) + + self.fc = nn.Sequential( + nn.Linear(config.node_features, config.fc_features), nn.SiLU() + ) + self.sigmoid = nn.Sigmoid() + + if self.classification: + self.fc_out = nn.Linear(config.fc_features, 2) + self.softmax = nn.LogSoftmax(dim=1) + else: + self.fc_out = nn.Linear( + config.fc_features, config.output_features + ) + + self.link = None + self.link_name = config.link + if config.link == "identity": + self.link = lambda x: x + elif config.link == "log": + self.link = torch.exp + avg_gap = 0.7 # magic number -- average bandgap in dft_3d + if not self.zero_inflated: + self.fc_out.bias.data = torch.tensor( + np.log(avg_gap), dtype=torch.float + ) + elif config.link == "logit": + self.link = torch.sigmoid + + def forward(self, data) -> torch.Tensor: + data, ldata, lattice = data + node_features = self.atom_embedding(data.x) + n_nodes = node_features.shape[0] + edge_feat = -0.75 / torch.norm(data.edge_attr, dim=1) + # lat_feat = -0.75 / torch.norm(data.atom_lat.view(n_nodes * 3, 3), dim=1) + # edge_nei_len = -0.75 / torch.norm(data.edge_nei, dim=-1) # [num_edges, 3] + # edge_nei_angle = bond_cosine(data.edge_nei, data.edge_attr.unsqueeze(1).repeat(1, 3, 1)) # [num_edges, 3, 3] -> [num_edges, 3] + num_edge = edge_feat.shape[0] + edge_features = self.rbf(edge_feat) + # lat_features = self.rbf(lat_feat).view(n_nodes, 3, -1) + # edge_nei_len = self.rbf(edge_nei_len.view(-1)).view(num_edge, 3, -1) + # edge_nei_angle = self.rbf_angle(edge_nei_angle.view(-1)).view(num_edge, 3, -1) + + node_features = self.att_layers[0](node_features, data.edge_index, edge_features) # / math.sqrt(16) + # edge_features = self.edge_update_layer(edge_features, edge_nei_len, edge_nei_angle) # / math.sqrt(4) + # node_features = self.att_layers[1](node_features, data.edge_index, edge_features) # / math.sqrt(16) + 
node_features = self.equi_update(data, node_features, data.edge_index, edge_features) + node_features = self.att_layers[2](node_features, data.edge_index, edge_features) # / math.sqrt(16) + + # crystal-level readout + features = scatter(node_features, data.batch, dim=0, reduce="mean") + + # features = F.softplus(features) + features = self.fc(features) + + out = self.fc_out(features) + if self.link: + out = self.link(out) + if self.classification: + out = self.softmax(out) + + return torch.squeeze(out) + + + + +class MatformerInvariant(nn.Module): + """att pyg implementation.""" + + def __init__(self, config: MatformerConfig = MatformerConfig(name="matformer")): + """Set up att modules.""" + super().__init__() + print("Using invariant marformer !!!!!!!!!!!!!!!!!!!!!!!!") + self.classification = config.classification + self.use_angle = config.use_angle + self.atom_embedding = nn.Linear( + config.atom_input_features, config.node_features + ) + self.rbf = nn.Sequential( + RBFExpansion( + vmin=-4.0, + vmax=0.0, + bins=config.edge_features, + ), + nn.Linear(config.edge_features, config.node_features), + nn.Softplus(), + ) + + self.rbf_angle = nn.Sequential( + RBFExpansion( + vmin=-1.0, + vmax=1.0, + bins=config.triplet_input_features, + ), + nn.Linear(config.triplet_input_features, config.node_features), + nn.Softplus(), + ) + + self.att_layers = nn.ModuleList( + [ + MatformerConv(in_channels=config.node_features, out_channels=config.node_features, heads=config.node_layer_head, edge_dim=config.node_features) + for _ in range(config.conv_layers) + ] + ) + + self.edge_update_layer = MatformerConv_edge(in_channels=config.node_features, out_channels=config.node_features, heads=config.node_layer_head, edge_dim=config.node_features) + + self.fc = nn.Sequential( + nn.Linear(config.node_features, config.fc_features), nn.SiLU() + ) + self.sigmoid = nn.Sigmoid() + + if self.classification: + self.fc_out = nn.Linear(config.fc_features, 2) + self.softmax = nn.LogSoftmax(dim=1) + else: + self.fc_out = nn.Linear( + config.fc_features, config.output_features + ) + + self.link = None + self.link_name = config.link + if config.link == "identity": + self.link = lambda x: x + elif config.link == "log": + self.link = torch.exp + avg_gap = 0.7 # magic number -- average bandgap in dft_3d + if not self.zero_inflated: + self.fc_out.bias.data = torch.tensor( + np.log(avg_gap), dtype=torch.float + ) + elif config.link == "logit": + self.link = torch.sigmoid + + def forward(self, data) -> torch.Tensor: + data, ldata, lattice = data + node_features = self.atom_embedding(data.x) + edge_feat = -0.75 / torch.norm(data.edge_attr, dim=1) # [num_edges] + edge_nei_len = -0.75 / torch.norm(data.edge_nei, dim=-1) # [num_edges, 3] + edge_nei_angle = bond_cosine(data.edge_nei, data.edge_attr.unsqueeze(1).repeat(1, 3, 1)) # [num_edges, 3, 3] -> [num_edges, 3] + num_edge = edge_feat.shape[0] + edge_features = self.rbf(edge_feat) + edge_nei_len = self.rbf(edge_nei_len.reshape(-1)).reshape(num_edge, 3, -1) + edge_nei_angle = self.rbf_angle(edge_nei_angle.reshape(-1)).reshape(num_edge, 3, -1) + + node_features = self.att_layers[0](node_features, data.edge_index, edge_features) # / math.sqrt(16) + edge_features = self.edge_update_layer(edge_features, edge_nei_len, edge_nei_angle) # / math.sqrt(4) + node_features = self.att_layers[1](node_features, data.edge_index, edge_features) # / math.sqrt(16) + # edge_features = self.edge_update_layers[1](edge_features, ldata.edge_index, angle_features) + node_features = 
self.att_layers[2](node_features, data.edge_index, edge_features) # / math.sqrt(16) + # node_features = self.att_layers[3](node_features, data.edge_index, edge_features) # / math.sqrt(16) + # node_features = self.att_layers[4](node_features, data.edge_index, edge_features) # / math.sqrt(16) + + # crystal-level readout + features = scatter(node_features, data.batch, dim=0, reduce="mean") + + # features = F.softplus(features) + features = self.fc(features) + + out = self.fc_out(features) + if self.link: + out = self.link(out) + if self.classification: + out = self.softmax(out) + + return torch.squeeze(out) + + diff --git a/benchmarks/matbench_v0.1_eComFormer/models/transformer.py b/benchmarks/matbench_v0.1_eComFormer/models/transformer.py new file mode 100644 index 00000000..cf504556 --- /dev/null +++ b/benchmarks/matbench_v0.1_eComFormer/models/transformer.py @@ -0,0 +1,282 @@ +import math +from e3nn import o3 +from typing import Optional, Tuple, Union + +import torch +import torch.nn.functional as F +from torch import Tensor +from torch_sparse import SparseTensor +import torch.nn as nn + +from torch_geometric.nn.conv import MessagePassing +from torch_geometric.typing import Adj, OptTensor, PairTensor +from matformer.models.utils import softmax +from torch_scatter import scatter + + +class MatformerConv(MessagePassing): + _alpha: OptTensor + + def __init__( + self, + in_channels: Union[int, Tuple[int, int]], + out_channels: int, + heads: int = 1, + concat: bool = True, + beta: bool = False, + dropout: float = 0.0, + edge_dim: Optional[int] = None, + bias: bool = True, + root_weight: bool = True, + **kwargs, + ): + kwargs.setdefault('aggr', 'add') + super(MatformerConv, self).__init__(node_dim=0, **kwargs) + + self.in_channels = in_channels + self.out_channels = out_channels + self.heads = heads + self.beta = beta and root_weight + self.root_weight = root_weight + self.concat = concat + self.dropout = dropout + self.edge_dim = edge_dim + self._alpha = None + + if isinstance(in_channels, int): + in_channels = (in_channels, in_channels) + + self.lin_key = nn.Linear(in_channels[0], heads * out_channels) + self.lin_query = nn.Linear(in_channels[1], heads * out_channels) + self.lin_value = nn.Linear(in_channels[0], heads * out_channels) + self.lin_edge = nn.Linear(edge_dim, heads * out_channels) + self.lin_concate = nn.Linear(heads * out_channels, out_channels) + + self.lin_msg_update = nn.Sequential(nn.Linear(out_channels * 3, out_channels), + nn.SiLU(), + nn.Linear(out_channels, out_channels)) + self.softplus = nn.Softplus() + self.silu = nn.SiLU() + self.key_update = nn.Sequential(nn.Linear(out_channels * 3, out_channels), + nn.SiLU(), + nn.Linear(out_channels, out_channels)) + self.bn = nn.BatchNorm1d(out_channels) + self.bn_att = nn.BatchNorm1d(out_channels) + self.sigmoid = nn.Sigmoid() + print('I am using the correct version of matformer') + + def forward(self, x: Union[Tensor, PairTensor], edge_index: Adj, + edge_attr: OptTensor = None, return_attention_weights=None): + + H, C = self.heads, self.out_channels + if isinstance(x, Tensor): + x: PairTensor = (x, x) + + query = self.lin_query(x[1]).view(-1, H, C) + key = self.lin_key(x[0]).view(-1, H, C) + value = self.lin_value(x[0]).view(-1, H, C) + + out = self.propagate(edge_index, query=query, key=key, value=value, + edge_attr=edge_attr, size=None) + + out = out.view(-1, self.heads * self.out_channels) + out = self.lin_concate(out) + + return self.softplus(x[1] + self.bn(out)) + + def message(self, query_i: Tensor, key_i: Tensor, key_j: 
Tensor, value_j: Tensor, value_i: Tensor, + edge_attr: OptTensor, index: Tensor, ptr: OptTensor, + size_i: Optional[int]) -> Tensor: + + edge_attr = self.lin_edge(edge_attr).view(-1, self.heads, self.out_channels) + key_j = self.key_update(torch.cat((key_i, key_j, edge_attr), dim=-1)) + alpha = (query_i * key_j) / math.sqrt(self.out_channels) + out = self.lin_msg_update(torch.cat((value_i, value_j, edge_attr), dim=-1)) + out = out * self.sigmoid(self.bn_att(alpha.view(-1, self.out_channels)).view(-1, self.heads, self.out_channels)) + return out + + +class MatformerConv_edge(nn.Module): + def __init__( + self, + in_channels: Union[int, Tuple[int, int]], + out_channels: int, + heads: int = 1, + concat: bool = True, + beta: bool = False, + dropout: float = 0.0, + edge_dim: Optional[int] = None, + bias: bool = True, + root_weight: bool = True, + ): + super().__init__() + + self.in_channels = in_channels + self.out_channels = out_channels + self.heads = heads + self.beta = beta and root_weight + self.root_weight = root_weight + self.concat = concat + self.dropout = dropout + self.edge_dim = edge_dim + + if isinstance(in_channels, int): + in_channels = (in_channels, in_channels) + self.lemb = nn.Embedding(num_embeddings=3, embedding_dim=32) + self.embedding_dim = 32 + self.lin_key = nn.Linear(in_channels[0], heads * out_channels) + self.lin_query = nn.Linear(in_channels[1], heads * out_channels) + self.lin_value = nn.Linear(in_channels[0], heads * out_channels) + # for test + self.lin_key_e1 = nn.Linear(in_channels[0], heads * out_channels) + self.lin_value_e1 = nn.Linear(in_channels[0], heads * out_channels) + self.lin_key_e2 = nn.Linear(in_channels[0], heads * out_channels) + self.lin_value_e2 = nn.Linear(in_channels[0], heads * out_channels) + self.lin_key_e3 = nn.Linear(in_channels[0], heads * out_channels) + self.lin_value_e3 = nn.Linear(in_channels[0], heads * out_channels) + # for test ends + self.lin_edge = nn.Linear(edge_dim, heads * out_channels, bias=False) + self.lin_edge_len = nn.Linear(in_channels[0] + self.embedding_dim, in_channels[0]) + self.lin_concate = nn.Linear(heads * out_channels, out_channels) + self.lin_msg_update = nn.Sequential(nn.Linear(out_channels * 3, out_channels), + nn.SiLU(), + nn.Linear(out_channels, out_channels)) + self.silu = nn.SiLU() + self.softplus = nn.Softplus() + self.key_update = nn.Sequential(nn.Linear(out_channels * 3, out_channels), + nn.SiLU(), + nn.Linear(out_channels, out_channels)) + self.bn_att = nn.BatchNorm1d(out_channels) + + self.bn = nn.BatchNorm1d(out_channels) + self.sigmoid = nn.Sigmoid() + print('I am using the invariant version of EPCNet') + + def forward(self, edge: Union[Tensor, PairTensor], edge_nei_len: OptTensor = None, edge_nei_angle: OptTensor = None): + # preprocess for edge of shape [num_edges, hidden_dim] + + H, C = self.heads, self.out_channels + if isinstance(edge, Tensor): + edge: PairTensor = (edge, edge) + device = edge[1].device + query_x = self.lin_query(edge[1]).view(-1, H, C).unsqueeze(1).repeat(1, 3, 1, 1) + key_x = self.lin_key(edge[0]).view(-1, H, C).unsqueeze(1).repeat(1, 3, 1, 1) + value_x = self.lin_value(edge[0]).view(-1, H, C).unsqueeze(1).repeat(1, 3, 1, 1) + num_edge = query_x.shape[0] + # preprocess for edge_neighbor of shape [num_edges, 3, hidden_dim] + # lembs = torch.cat((self.lemb(torch.tensor([0]).to(device)), self.lemb(torch.tensor([1]).to(device)), self.lemb(torch.tensor([2]).to(device))), dim=0).unsqueeze(0).repeat(num_edge, 1, 1) + # edge_nei_len = self.lin_edge_len(torch.cat((edge_nei_len, 
lembs), dim=-1)) + # query_y = self.lin_query(edge_nei_len).view(-1, 3, H, C) + # key_y = self.lin_key(edge_nei_len).view(-1, 3, H, C) + # value_y = self.lin_value(edge_nei_len).view(-1, 3, H, C) + + # test begin + key_y = torch.cat((self.lin_key_e1(edge_nei_len[:,0,:]).view(-1, 1, H, C), + self.lin_key_e2(edge_nei_len[:,1,:]).view(-1, 1, H, C), + self.lin_key_e3(edge_nei_len[:,2,:]).view(-1, 1, H, C)), dim=1) + value_y = torch.cat((self.lin_value_e1(edge_nei_len[:,0,:]).view(-1, 1, H, C), + self.lin_value_e2(edge_nei_len[:,1,:]).view(-1, 1, H, C), + self.lin_value_e3(edge_nei_len[:,2,:]).view(-1, 1, H, C)), dim=1) + # test end + + # preprocess for interaction of shape [num_edges, 3, hidden_dim] + edge_xy = self.lin_edge(edge_nei_angle).view(-1, 3, H, C) + + key = self.key_update(torch.cat((key_x, key_y, edge_xy), dim=-1)) + alpha = (query_x * key) / math.sqrt(self.out_channels) + out = self.lin_msg_update(torch.cat((value_x, value_y, edge_xy), dim=-1)) + out = out * self.sigmoid(self.bn_att(alpha.view(-1, self.out_channels)).view(-1, 3, self.heads, self.out_channels)) + + out = out.view(-1, 3, self.heads * self.out_channels) + out = self.lin_concate(out) + # aggregate the msg + out = out.sum(dim=1) + + return self.softplus(edge[1] + self.bn(out)) + + + +class TensorProductConvLayer(torch.nn.Module): + # from Torsional diffusion + def __init__(self, in_irreps, sh_irreps, out_irreps, n_edge_features, residual=True): + super(TensorProductConvLayer, self).__init__() + self.in_irreps = in_irreps + self.out_irreps = out_irreps + self.sh_irreps = sh_irreps + self.residual = residual + + self.tp = tp = o3.FullyConnectedTensorProduct(in_irreps, sh_irreps, out_irreps, shared_weights=False) + + self.fc = nn.Sequential( + nn.Linear(n_edge_features, n_edge_features), + nn.Softplus(), + nn.Linear(n_edge_features, tp.weight_numel) + ) + + def forward(self, node_attr, edge_index, edge_attr, edge_sh, out_nodes=None, reduce='mean'): + + edge_src, edge_dst = edge_index + tp = self.tp(node_attr[edge_dst], edge_sh, self.fc(edge_attr)) + + out_nodes = out_nodes or node_attr.shape[0] + out = scatter(tp, edge_src, dim=0, dim_size=out_nodes, reduce=reduce) + if self.residual: + padded = F.pad(node_attr, (0, out.shape[-1] - node_attr.shape[-1])) + out = out + padded + + return out + + +class MatformerConvEqui(nn.Module): + def __init__( + self, + in_channels: Union[int, Tuple[int, int]], + out_channels: int, + edge_dim: Optional[int] = None, + use_second_order_repr: bool = True, + ns: int = 128, + nv: int = 8, + residual: bool = True, + ): + super().__init__() + + irrep_seq = [ + f'{ns}x0e', + f'{ns}x0e + {nv}x1o + {nv}x2e', + f'{ns}x0e' + ] + self.ns, self.nv = ns, nv + self.node_linear = nn.Linear(in_channels, ns) + self.skip_linear = nn.Linear(in_channels, out_channels) + self.sh = '1x0e + 1x1o + 1x2e' + self.nlayer_1 = TensorProductConvLayer( + in_irreps=irrep_seq[0], + sh_irreps=self.sh, + out_irreps=irrep_seq[1], + n_edge_features=edge_dim, + residual=residual + ) + self.nlayer_2 = TensorProductConvLayer( + in_irreps=irrep_seq[1], + sh_irreps=self.sh, + out_irreps=irrep_seq[2], + n_edge_features=edge_dim, + residual=False + ) + self.softplus = nn.Softplus() + self.bn = nn.BatchNorm1d(ns) + self.node_linear_2 = nn.Linear(ns, out_channels) + + def forward(self, data, node_feature: Union[Tensor, PairTensor], edge_index: Adj, edge_feature: Union[Tensor, PairTensor], + edge_nei_len: OptTensor = None): + edge_vec = data.edge_attr + edge_irr = o3.spherical_harmonics(self.sh, edge_vec, normalize=True, 
normalization='component') + n_ = node_feature.shape[0] + skip_connect = node_feature + node_feature = self.node_linear(node_feature) + node_feature = self.nlayer_1(node_feature, edge_index, edge_feature, edge_irr) + node_feature = self.nlayer_2(node_feature, edge_index, edge_feature, edge_irr) + node_feature = self.softplus(self.node_linear_2(self.softplus(self.bn(node_feature)))) + node_feature += self.skip_linear(skip_connect) + + return node_feature \ No newline at end of file diff --git a/benchmarks/matbench_v0.1_eComFormer/models/utils.py b/benchmarks/matbench_v0.1_eComFormer/models/utils.py new file mode 100644 index 00000000..aa01ef1b --- /dev/null +++ b/benchmarks/matbench_v0.1_eComFormer/models/utils.py @@ -0,0 +1,126 @@ +"""Shared model-building components.""" +from typing import Optional + +import numpy as np +import torch +from torch import nn + +from torch import Tensor +from torch_scatter import gather_csr, scatter, segment_csr + +from torch_geometric.utils.num_nodes import maybe_num_nodes + +class RBFExpansion(nn.Module): + """Expand interatomic distances with radial basis functions.""" + + def __init__( + self, + vmin: float = 0, + vmax: float = 8, + bins: int = 40, + lengthscale: Optional[float] = None, + ): + """Register torch parameters for RBF expansion.""" + super().__init__() + self.vmin = vmin + self.vmax = vmax + self.bins = bins + self.register_buffer( + "centers", torch.linspace(self.vmin, self.vmax, self.bins) + ) + + if lengthscale is None: + # SchNet-style + # set lengthscales relative to granularity of RBF expansion + self.lengthscale = np.diff(self.centers).mean() + self.gamma = 1 / self.lengthscale + + else: + self.lengthscale = lengthscale + self.gamma = 1 / (lengthscale ** 2) + + def forward(self, distance: torch.Tensor) -> torch.Tensor: + """Apply RBF expansion to interatomic distance tensor.""" + return torch.exp( + -self.gamma * (distance.unsqueeze(1) - self.centers) ** 2 + ) + + +@torch.jit.script +def softmax(src: Tensor, index: Optional[Tensor] = None, + ptr: Optional[Tensor] = None, num_nodes: Optional[int] = None, + dim: int = 0) -> Tensor: + r"""Computes a sparsely evaluated softmax. + Given a value tensor :attr:`src`, this function first groups the values + along the first dimension based on the indices specified in :attr:`index`, + and then proceeds to compute the softmax individually for each group. + Args: + src (Tensor): The source tensor. + index (LongTensor, optional): The indices of elements for applying the + softmax. (default: :obj:`None`) + ptr (LongTensor, optional): If given, computes the softmax based on + sorted inputs in CSR representation. (default: :obj:`None`) + num_nodes (int, optional): The number of nodes, *i.e.* + :obj:`max_val + 1` of :attr:`index`. (default: :obj:`None`) + dim (int, optional): The dimension in which to normalize. 
+ (default: :obj:`0`) + :rtype: :class:`Tensor` + """ + if ptr is not None: + dim = dim + src.dim() if dim < 0 else dim + size = ([1] * dim) + [-1] + ptr = ptr.view(size) + src_max = gather_csr(segment_csr(src, ptr, reduce='max'), ptr) + out = (src - src_max).exp() + out_sum = gather_csr(segment_csr(out, ptr, reduce='sum'), ptr) + elif index is not None: + N = maybe_num_nodes(index, num_nodes) + src_max = scatter(src, index, dim, dim_size=N, reduce='max') + src_max = src_max.index_select(dim, index) + out = (src - src_max).exp() + out_sum = scatter(out, index, dim, dim_size=N, reduce='sum') + out_sum = out_sum.index_select(dim, index) + else: + raise NotImplementedError + + return out / (out_sum + 1e-16) + + +@torch.jit.script +def softmax_vec(src: Tensor, index: Optional[Tensor] = None, + ptr: Optional[Tensor] = None, num_nodes: Optional[int] = None, + dim: int = 0) -> Tensor: + r"""Computes a sparsely evaluated softmax. + Given a value tensor :attr:`src`, this function first groups the values + along the first dimension based on the indices specified in :attr:`index`, + and then proceeds to compute the softmax individually for each group. + Args: + src (Tensor): The source tensor. + index (LongTensor, optional): The indices of elements for applying the + softmax. (default: :obj:`None`) + ptr (LongTensor, optional): If given, computes the softmax based on + sorted inputs in CSR representation. (default: :obj:`None`) + num_nodes (int, optional): The number of nodes, *i.e.* + :obj:`max_val + 1` of :attr:`index`. (default: :obj:`None`) + dim (int, optional): The dimension in which to normalize. + (default: :obj:`0`) + :rtype: :class:`Tensor` + """ + if ptr is not None: + dim = dim + src.dim() if dim < 0 else dim + size = ([1] * dim) + [-1] + ptr = ptr.view(size) + src_max = gather_csr(segment_csr(src, ptr, reduce='max'), ptr) + out = (src - src_max).exp() + out_sum = gather_csr(segment_csr(out, ptr, reduce='sum'), ptr) + elif index is not None: + N = maybe_num_nodes(index, num_nodes) + src_max = scatter(src, index, dim, dim_size=N, reduce='max') + src_max = src_max.index_select(dim, index) + out = (src - src_max).exp() + out_sum = scatter(out, index, dim, dim_size=N, reduce='sum') + out_sum = out_sum.index_select(dim, index) + else: + raise NotImplementedError + + return out / (out_sum + 1e-16) \ No newline at end of file diff --git a/benchmarks/matbench_v0.1_eComFormer/results.json.gz b/benchmarks/matbench_v0.1_eComFormer/results.json.gz new file mode 100644 index 00000000..a5e0fbbe Binary files /dev/null and b/benchmarks/matbench_v0.1_eComFormer/results.json.gz differ diff --git a/benchmarks/matbench_v0.1_eComFormer/scheduler.py b/benchmarks/matbench_v0.1_eComFormer/scheduler.py new file mode 100644 index 00000000..1c97817b --- /dev/null +++ b/benchmarks/matbench_v0.1_eComFormer/scheduler.py @@ -0,0 +1,244 @@ +import types +import math +import torch +from torch import inf +from functools import wraps +import warnings +import weakref +from collections import Counter +from bisect import bisect_right +from torch.optim import Optimizer + +class LRScheduler: + + def __init__(self, optimizer, last_epoch=-1, verbose=False): + + # Attach optimizer + if not isinstance(optimizer, Optimizer): + raise TypeError('{} is not an Optimizer'.format( + type(optimizer).__name__)) + self.optimizer = optimizer + + # Initialize epoch and base learning rates + if last_epoch == -1: + for group in optimizer.param_groups: + group.setdefault('initial_lr', group['lr']) + else: + for i, group in 
enumerate(optimizer.param_groups): + if 'initial_lr' not in group: + raise KeyError("param 'initial_lr' is not specified " + "in param_groups[{}] when resuming an optimizer".format(i)) + self.base_lrs = [group['initial_lr'] for group in optimizer.param_groups] + self.last_epoch = last_epoch + + # Following https://github.com/pytorch/pytorch/issues/20124 + # We would like to ensure that `lr_scheduler.step()` is called after + # `optimizer.step()` + def with_counter(method): + if getattr(method, '_with_counter', False): + # `optimizer.step()` has already been replaced, return. + return method + + # Keep a weak reference to the optimizer instance to prevent + # cyclic references. + instance_ref = weakref.ref(method.__self__) + # Get the unbound method for the same purpose. + func = method.__func__ + cls = instance_ref().__class__ + del method + + @wraps(func) + def wrapper(*args, **kwargs): + instance = instance_ref() + instance._step_count += 1 + wrapped = func.__get__(instance, cls) + return wrapped(*args, **kwargs) + + # Note that the returned function here is no longer a bound method, + # so attributes like `__func__` and `__self__` no longer exist. + wrapper._with_counter = True + return wrapper + + self.optimizer.step = with_counter(self.optimizer.step) + self.verbose = verbose + + self._initial_step() + + def _initial_step(self): + """Initialize step counts and performs a step""" + self.optimizer._step_count = 0 + self._step_count = 0 + self.step() + + def state_dict(self): + """Returns the state of the scheduler as a :class:`dict`. + + It contains an entry for every variable in self.__dict__ which + is not the optimizer. + """ + return {key: value for key, value in self.__dict__.items() if key != 'optimizer'} + + def load_state_dict(self, state_dict): + """Loads the schedulers state. + + Args: + state_dict (dict): scheduler state. Should be an object returned + from a call to :meth:`state_dict`. + """ + self.__dict__.update(state_dict) + + def get_last_lr(self): + """ Return last computed learning rate by current scheduler. + """ + return self._last_lr + + def get_lr(self): + # Compute learning rate using chainable form of the scheduler + raise NotImplementedError + + def print_lr(self, is_verbose, group, lr, epoch=None): + """Display the current learning rate. + """ + if is_verbose: + if epoch is None: + print('Adjusting learning rate' + ' of group {} to {:.4e}.'.format(group, lr)) + else: + epoch_str = ("%.2f" if isinstance(epoch, float) else + "%.5d") % epoch + print('Epoch {}: adjusting learning rate' + ' of group {} to {:.4e}.'.format(epoch_str, group, lr)) + + + def step(self, epoch=None): + # Raise a warning if old pattern is detected + # https://github.com/pytorch/pytorch/issues/20124 + if self._step_count == 1: + if not hasattr(self.optimizer.step, "_with_counter"): + warnings.warn("Seems like `optimizer.step()` has been overridden after learning rate scheduler " + "initialization. Please, make sure to call `optimizer.step()` before " + "`lr_scheduler.step()`. See more details at " + "https://pytorch.org/docs/stable/optim.html#how-to-adjust-learning-rate", UserWarning) + + # Just check if there were two first lr_scheduler.step() calls before optimizer.step() + elif self.optimizer._step_count < 1: + warnings.warn("Detected call of `lr_scheduler.step()` before `optimizer.step()`. " + "In PyTorch 1.1.0 and later, you should call them in the opposite order: " + "`optimizer.step()` before `lr_scheduler.step()`. 
Failure to do this " + "will result in PyTorch skipping the first value of the learning rate schedule. " + "See more details at " + "https://pytorch.org/docs/stable/optim.html#how-to-adjust-learning-rate", UserWarning) + self._step_count += 1 + + with _enable_get_lr_call(self): + if epoch is None: + self.last_epoch += 1 + values = self.get_lr() + else: + warnings.warn(EPOCH_DEPRECATION_WARNING, UserWarning) + self.last_epoch = epoch + if hasattr(self, "_get_closed_form_lr"): + values = self._get_closed_form_lr() + else: + values = self.get_lr() + + for i, data in enumerate(zip(self.optimizer.param_groups, values)): + param_group, lr = data + param_group['lr'] = lr + self.print_lr(self.verbose, i, lr, epoch) + + self._last_lr = [group['lr'] for group in self.optimizer.param_groups] + +class _enable_get_lr_call: + + def __init__(self, o): + self.o = o + + def __enter__(self): + self.o._get_lr_called_within_step = True + return self + + def __exit__(self, type, value, traceback): + self.o._get_lr_called_within_step = False + +class PolynomialLR(LRScheduler): + """Decays the learning rate of each parameter group using a polynomial function + in the given total_iters. When last_epoch=-1, sets initial lr as lr. + + Args: + optimizer (Optimizer): Wrapped optimizer. + total_iters (int): The number of steps that the scheduler decays the learning rate. Default: 5. + power (int): The power of the polynomial. Default: 1.0. + verbose (bool): If ``True``, prints a message to stdout for + each update. Default: ``False``. + + Example: + >>> # xdoctest: +SKIP("undefined vars") + >>> # Assuming optimizer uses lr = 0.001 for all groups + >>> # lr = 0.001 if epoch == 0 + >>> # lr = 0.00075 if epoch == 1 + >>> # lr = 0.00050 if epoch == 2 + >>> # lr = 0.00025 if epoch == 3 + >>> # lr = 0.0 if epoch >= 4 + >>> scheduler = PolynomialLR(self.opt, total_iters=4, power=1.0) + >>> for epoch in range(100): + >>> train(...) + >>> validate(...) 
+ >>> scheduler.step() + """ + def __init__(self, optimizer, total_iters=5, power=1.0, last_epoch=-1, verbose=False, end_lr=1e-5, decay_steps=10): + self.total_iters = total_iters + self.power = power + self.end_lr = end_lr + self.decay_steps = decay_steps + super().__init__(optimizer, last_epoch, verbose) + + def get_lr(self): + # print(self.last_epoch) + # print(self._step_count) + + # if not self._get_lr_called_within_step: + # warnings.warn("To get the last learning rate computed by the scheduler, " + # "please use `get_last_lr()`.", UserWarning) + + step = min(self.last_epoch, self.decay_steps) + return [ + ( + (base_lr - self.end_lr) * (1.0 - step / self.decay_steps) ** self.power + self.end_lr + ) + for base_lr in self.base_lrs + ] + + # if not self._get_lr_called_within_step: + # warnings.warn("To get the last learning rate computed by the scheduler, " + # "please use `get_last_lr()`.", UserWarning) + + # if self.last_epoch == 0 or self.last_epoch > self.total_iters: + # return [group["lr"] for group in self.optimizer.param_groups] + + # decay_factor = ((1.0 - self.last_epoch / self.total_iters) / (1.0 - (self.last_epoch - 1) / self.total_iters)) ** self.power + # return [group["lr"] * decay_factor for group in self.optimizer.param_groups] + + def _get_closed_form_lr(self): + + print(self.last_epoch) + step = min(self.last_epoch, self.decay_steps) + return [ + ( + (base_lr - self.end_lr) * (1.0 - self.last_epoch / self.decay_steps) ** self.power + self.end_lr + ) + for base_lr in self.base_lrs + ] + + # return [ + # ( + # base_lr * (1.0 - min(self.total_iters, self.last_epoch) / self.total_iters) ** self.power + # ) + # for base_lr in self.base_lrs + # ] + + # def decayed_learning_rate(step): + # step = min(step, decay_steps) + # return ((initial_learning_rate - end_learning_rate) * + # (1 - step / decay_steps) ^ (power) + # ) + end_learning_rate \ No newline at end of file diff --git a/benchmarks/matbench_v0.1_eComFormer/train.py b/benchmarks/matbench_v0.1_eComFormer/train.py new file mode 100644 index 00000000..ff506fa4 --- /dev/null +++ b/benchmarks/matbench_v0.1_eComFormer/train.py @@ -0,0 +1,913 @@ +from functools import partial + +# from pathlib import Path +from typing import Any, Dict, Union + +import ignite +import torch + +from ignite.contrib.handlers import TensorboardLogger +try: + from ignite.contrib.handlers.stores import EpochOutputStore +except Exception as exp: + from ignite.handlers.stores import EpochOutputStore + + pass +from ignite.handlers import EarlyStopping +from ignite.contrib.handlers.tensorboard_logger import ( + global_step_from_engine, +) +from ignite.contrib.handlers.tqdm_logger import ProgressBar +from ignite.engine import ( + Events, + create_supervised_evaluator, + create_supervised_trainer, +) +from ignite.contrib.metrics import ROC_AUC, RocCurve +from ignite.metrics import ( + Accuracy, + Precision, + Recall, + ConfusionMatrix, +) +import pickle as pk +import numpy as np +from ignite.handlers import Checkpoint, DiskSaver, TerminateOnNan +from ignite.metrics import Loss, MeanAbsoluteError +from torch import nn +from matformer import models +from matformer.data import get_train_val_loaders +from matformer.config import TrainingConfig +# from matformer.models.pyg_att import Matformer + +from jarvis.db.jsonutils import dumpjson +import json +import pprint + +import os + +# import sys +# sys.path.append("/mnt/data/shared/congfu/CompCrystal/NewModel_27/matformer/") +# from scheduler import PolynomialLR + + +import types +import math +import torch +from 
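The decay implemented by get_lr above mirrors the TensorFlow-style polynomial decay sketched in the trailing comment: lr(step) = (base_lr - end_lr) * (1 - min(step, decay_steps) / decay_steps) ** power + end_lr (get_lr clamps the step at decay_steps; _get_closed_form_lr computes the same clamp but then uses last_epoch directly, so the two diverge once last_epoch exceeds decay_steps). A minimal standalone sketch of that formula, using illustrative values (base_lr=1e-3, end_lr=1e-5, decay_steps=4, power=1.0) rather than anything taken from the training configs:

def polynomial_lr(step, base_lr=1e-3, end_lr=1e-5, decay_steps=4, power=1.0):
    # same expression as PolynomialLR.get_lr above, for a single parameter group
    step = min(step, decay_steps)
    return (base_lr - end_lr) * (1.0 - step / decay_steps) ** power + end_lr

# 0.001, 0.0007525, 0.000505, 0.0002575, 1e-05, 1e-05
print([round(polynomial_lr(s), 7) for s in range(6)])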
torch import inf +from functools import wraps +import warnings +import weakref +from collections import Counter +from bisect import bisect_right +from torch.optim import Optimizer + +class LRScheduler: + + def __init__(self, optimizer, last_epoch=-1, verbose=False): + + # Attach optimizer + if not isinstance(optimizer, Optimizer): + raise TypeError('{} is not an Optimizer'.format( + type(optimizer).__name__)) + self.optimizer = optimizer + + # Initialize epoch and base learning rates + if last_epoch == -1: + for group in optimizer.param_groups: + group.setdefault('initial_lr', group['lr']) + else: + for i, group in enumerate(optimizer.param_groups): + if 'initial_lr' not in group: + raise KeyError("param 'initial_lr' is not specified " + "in param_groups[{}] when resuming an optimizer".format(i)) + self.base_lrs = [group['initial_lr'] for group in optimizer.param_groups] + self.last_epoch = last_epoch + + # Following https://github.com/pytorch/pytorch/issues/20124 + # We would like to ensure that `lr_scheduler.step()` is called after + # `optimizer.step()` + def with_counter(method): + if getattr(method, '_with_counter', False): + # `optimizer.step()` has already been replaced, return. + return method + + # Keep a weak reference to the optimizer instance to prevent + # cyclic references. + instance_ref = weakref.ref(method.__self__) + # Get the unbound method for the same purpose. + func = method.__func__ + cls = instance_ref().__class__ + del method + + @wraps(func) + def wrapper(*args, **kwargs): + instance = instance_ref() + instance._step_count += 1 + wrapped = func.__get__(instance, cls) + return wrapped(*args, **kwargs) + + # Note that the returned function here is no longer a bound method, + # so attributes like `__func__` and `__self__` no longer exist. + wrapper._with_counter = True + return wrapper + + self.optimizer.step = with_counter(self.optimizer.step) + self.verbose = verbose + + self._initial_step() + + def _initial_step(self): + """Initialize step counts and performs a step""" + self.optimizer._step_count = 0 + self._step_count = 0 + self.step() + + def state_dict(self): + """Returns the state of the scheduler as a :class:`dict`. + + It contains an entry for every variable in self.__dict__ which + is not the optimizer. + """ + return {key: value for key, value in self.__dict__.items() if key != 'optimizer'} + + def load_state_dict(self, state_dict): + """Loads the schedulers state. + + Args: + state_dict (dict): scheduler state. Should be an object returned + from a call to :meth:`state_dict`. + """ + self.__dict__.update(state_dict) + + def get_last_lr(self): + """ Return last computed learning rate by current scheduler. + """ + return self._last_lr + + def get_lr(self): + # Compute learning rate using chainable form of the scheduler + raise NotImplementedError + + def print_lr(self, is_verbose, group, lr, epoch=None): + """Display the current learning rate. 
+ """ + if is_verbose: + if epoch is None: + print('Adjusting learning rate' + ' of group {} to {:.4e}.'.format(group, lr)) + else: + epoch_str = ("%.2f" if isinstance(epoch, float) else + "%.5d") % epoch + print('Epoch {}: adjusting learning rate' + ' of group {} to {:.4e}.'.format(epoch_str, group, lr)) + + + def step(self, epoch=None): + # Raise a warning if old pattern is detected + # https://github.com/pytorch/pytorch/issues/20124 + if self._step_count == 1: + if not hasattr(self.optimizer.step, "_with_counter"): + warnings.warn("Seems like `optimizer.step()` has been overridden after learning rate scheduler " + "initialization. Please, make sure to call `optimizer.step()` before " + "`lr_scheduler.step()`. See more details at " + "https://pytorch.org/docs/stable/optim.html#how-to-adjust-learning-rate", UserWarning) + + # Just check if there were two first lr_scheduler.step() calls before optimizer.step() + elif self.optimizer._step_count < 1: + warnings.warn("Detected call of `lr_scheduler.step()` before `optimizer.step()`. " + "In PyTorch 1.1.0 and later, you should call them in the opposite order: " + "`optimizer.step()` before `lr_scheduler.step()`. Failure to do this " + "will result in PyTorch skipping the first value of the learning rate schedule. " + "See more details at " + "https://pytorch.org/docs/stable/optim.html#how-to-adjust-learning-rate", UserWarning) + self._step_count += 1 + + with _enable_get_lr_call(self): + if epoch is None: + self.last_epoch += 1 + values = self.get_lr() + else: + warnings.warn(EPOCH_DEPRECATION_WARNING, UserWarning) + self.last_epoch = epoch + if hasattr(self, "_get_closed_form_lr"): + values = self._get_closed_form_lr() + else: + values = self.get_lr() + + for i, data in enumerate(zip(self.optimizer.param_groups, values)): + param_group, lr = data + param_group['lr'] = lr + self.print_lr(self.verbose, i, lr, epoch) + + self._last_lr = [group['lr'] for group in self.optimizer.param_groups] + +class _enable_get_lr_call: + + def __init__(self, o): + self.o = o + + def __enter__(self): + self.o._get_lr_called_within_step = True + return self + + def __exit__(self, type, value, traceback): + self.o._get_lr_called_within_step = False + +class PolynomialLR(LRScheduler): + """Decays the learning rate of each parameter group using a polynomial function + in the given total_iters. When last_epoch=-1, sets initial lr as lr. + + Args: + optimizer (Optimizer): Wrapped optimizer. + total_iters (int): The number of steps that the scheduler decays the learning rate. Default: 5. + power (int): The power of the polynomial. Default: 1.0. + verbose (bool): If ``True``, prints a message to stdout for + each update. Default: ``False``. + + Example: + >>> # xdoctest: +SKIP("undefined vars") + >>> # Assuming optimizer uses lr = 0.001 for all groups + >>> # lr = 0.001 if epoch == 0 + >>> # lr = 0.00075 if epoch == 1 + >>> # lr = 0.00050 if epoch == 2 + >>> # lr = 0.00025 if epoch == 3 + >>> # lr = 0.0 if epoch >= 4 + >>> scheduler = PolynomialLR(self.opt, total_iters=4, power=1.0) + >>> for epoch in range(100): + >>> train(...) + >>> validate(...) 
+ >>> scheduler.step() + """ + def __init__(self, optimizer, total_iters=5, power=1.0, last_epoch=-1, verbose=False, end_lr=1e-5, decay_steps=10): + self.total_iters = total_iters + self.power = power + self.end_lr = end_lr + self.decay_steps = decay_steps + super().__init__(optimizer, last_epoch, verbose) + + def get_lr(self): + # print(self.last_epoch) + # print(self._step_count) + + # if not self._get_lr_called_within_step: + # warnings.warn("To get the last learning rate computed by the scheduler, " + # "please use `get_last_lr()`.", UserWarning) + + step = min(self.last_epoch, self.decay_steps) + return [ + ( + (base_lr - self.end_lr) * (1.0 - step / self.decay_steps) ** self.power + self.end_lr + ) + for base_lr in self.base_lrs + ] + + # if not self._get_lr_called_within_step: + # warnings.warn("To get the last learning rate computed by the scheduler, " + # "please use `get_last_lr()`.", UserWarning) + + # if self.last_epoch == 0 or self.last_epoch > self.total_iters: + # return [group["lr"] for group in self.optimizer.param_groups] + + # decay_factor = ((1.0 - self.last_epoch / self.total_iters) / (1.0 - (self.last_epoch - 1) / self.total_iters)) ** self.power + # return [group["lr"] * decay_factor for group in self.optimizer.param_groups] + + def _get_closed_form_lr(self): + + print(self.last_epoch) + step = min(self.last_epoch, self.decay_steps) + return [ + ( + (base_lr - self.end_lr) * (1.0 - self.last_epoch / self.decay_steps) ** self.power + self.end_lr + ) + for base_lr in self.base_lrs + ] + + # return [ + # ( + # base_lr * (1.0 - min(self.total_iters, self.last_epoch) / self.total_iters) ** self.power + # ) + # for base_lr in self.base_lrs + # ] + + # def decayed_learning_rate(step): + # step = min(step, decay_steps) + # return ((initial_learning_rate - end_learning_rate) * + # (1 - step / decay_steps) ^ (power) + # ) + end_learning_rate + +########################################################################################### + + +# torch config +torch.set_default_dtype(torch.float32) + +device = "cpu" +if torch.cuda.is_available(): + device = torch.device("cuda") + + +def activated_output_transform(output): + """Exponentiate output.""" + y_pred, y = output + y_pred = torch.exp(y_pred) + y_pred = y_pred[:, 1] + return y_pred, y + + +def make_standard_scalar_and_pca(output): + """Use standard scalar and PCS for multi-output data.""" + sc = pk.load(open(os.path.join(tmp_output_dir, "sc.pkl"), "rb")) + y_pred, y = output + y_pred = torch.tensor(sc.transform(y_pred.cpu().numpy()), device=device) + y = torch.tensor(sc.transform(y.cpu().numpy()), device=device) + return y_pred, y + + +def thresholded_output_transform(output): + """Round off output.""" + y_pred, y = output + y_pred = torch.round(torch.exp(y_pred)) + # print ('output',y_pred) + return y_pred, y + + +def group_decay(model): + """Omit weight decay from bias and batchnorm params.""" + decay, no_decay = [], [] + + for name, p in model.named_parameters(): + if "bias" in name or "bn" in name or "norm" in name: + no_decay.append(p) + else: + decay.append(p) + + return [ + {"params": decay}, + {"params": no_decay, "weight_decay": 0}, + ] + + +def setup_optimizer(params, config: TrainingConfig): + """Set up optimizer for param groups.""" + if config.optimizer == "adamw": + optimizer = torch.optim.AdamW( + params, + lr=config.learning_rate, + weight_decay=config.weight_decay, + ) + elif config.optimizer == "sgd": + optimizer = torch.optim.SGD( + params, + lr=config.learning_rate, + momentum=0.9, + 
weight_decay=config.weight_decay, + ) + return optimizer + + +def train_dgl( + config: Union[TrainingConfig, Dict[str, Any]], + model: nn.Module = None, + train_val_test_loaders=[], + test_only=False, + use_save=True, + mp_id_list=None, + train_inputs=None, + train_outputs=None, + test_inputs=None, + test_outputs=None, + model_variant=None, +): + """ + `config` should conform to matformer.conf.TrainingConfig, and + if passed as a dict with matching keys, pydantic validation is used + """ + print(config) + if type(config) is dict: + try: + config = TrainingConfig(**config) + except Exception as exp: + print("Check", exp) + print('error in converting to training config!') + import os + + if not os.path.exists(config.output_dir): + os.makedirs(config.output_dir) + checkpoint_dir = os.path.join(config.output_dir) + deterministic = False + classification = False + print("config:") + tmp = config.dict() + f = open(os.path.join(config.output_dir, "config.json"), "w") + f.write(json.dumps(tmp, indent=4)) + f.close() + global tmp_output_dir + tmp_output_dir = config.output_dir + pprint.pprint(tmp) + if config.classification_threshold is not None: + classification = True + if config.random_seed is not None: + deterministic = True + ignite.utils.manual_seed(config.random_seed) + + # import pdb; pdb.set_trace() + line_graph = True + if not train_val_test_loaders: + # use input standardization for all real-valued feature sets + ( + train_loader, + val_loader, + test_loader, + prepare_batch, + mean_train, + std_train, + ) = get_train_val_loaders( + dataset=config.dataset, + target=config.target, + n_train=config.n_train, + n_val=config.n_val, + n_test=config.n_test, + train_ratio=config.train_ratio, + val_ratio=config.val_ratio, + test_ratio=config.test_ratio, + batch_size=config.batch_size, + atom_features=config.atom_features, + neighbor_strategy=config.neighbor_strategy, + standardize=config.atom_features != "cgcnn", + line_graph=line_graph, + id_tag=config.id_tag, + pin_memory=config.pin_memory, + workers=config.num_workers, + save_dataloader=config.save_dataloader, + use_canonize=config.use_canonize, + filename=config.filename, + cutoff=config.cutoff, + max_neighbors=config.max_neighbors, + output_features=config.model.output_features, + classification_threshold=config.classification_threshold, + target_multiplication_factor=config.target_multiplication_factor, + standard_scalar_and_pca=config.standard_scalar_and_pca, + keep_data_order=config.keep_data_order, + output_dir=config.output_dir, + matrix_input=config.matrix_input, + pyg_input=config.pyg_input, + use_lattice=config.use_lattice, + use_angle=config.use_angle, + use_save=use_save, + mp_id_list=mp_id_list, + train_inputs=train_inputs, + train_outputs=train_outputs, + test_inputs=test_inputs, + test_outputs=test_outputs, + ) + else: + train_loader = train_val_test_loaders[0] + val_loader = train_val_test_loaders[1] + test_loader = train_val_test_loaders[2] + prepare_batch = train_val_test_loaders[3] + prepare_batch = partial(prepare_batch, device=device) + if classification: + config.model.classification = True + # define network, optimizer, scheduler + if model_variant == 'matformerinvariant': + from matformer.models.pyg_att import MatformerInvariant as Matformer + elif model_variant == 'matformerequivariant': + from matformer.models.pyg_att import MatformerEquivariant as Matformer + _model = { + "matformer" : Matformer, + } + if std_train is None: + std_train = 1.0 + print('std train is none!') + print('std train:', std_train) + if model is 
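group_decay above routes biases and normalization parameters into a second parameter group whose weight decay is forced to zero, and setup_optimizer then builds AdamW or SGD over those two groups. A small sketch of the resulting param_groups on a toy module (the Toy class is only for illustration; the real network is the Matformer/ComFormer model):

import torch
from torch import nn

class Toy(nn.Module):
    # toy stand-in with one decayed weight and several no-decay parameters
    def __init__(self):
        super().__init__()
        self.lin = nn.Linear(4, 8)
        self.norm = nn.LayerNorm(8)
    def forward(self, x):
        return self.norm(self.lin(x))

decay, no_decay = [], []
for name, p in Toy().named_parameters():
    (no_decay if ("bias" in name or "bn" in name or "norm" in name) else decay).append(p)

opt = torch.optim.AdamW(
    [{"params": decay}, {"params": no_decay, "weight_decay": 0}],
    lr=1e-3, weight_decay=1e-5,
)
# group 0 inherits the optimizer-level decay, group 1 overrides it with 0
print([g["weight_decay"] for g in opt.param_groups])  # [1e-05, 0]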
None: + net = _model.get(config.model.name)(config.model) + print("config:") + pprint.pprint(config.model.dict()) + else: + net = model + + net.to(device) + if config.distributed: + import torch.distributed as dist + import os + + def setup(rank, world_size): + os.environ["MASTER_ADDR"] = "localhost" + os.environ["MASTER_PORT"] = "12355" + + # initialize the process group + dist.init_process_group("gloo", rank=rank, world_size=world_size) + + def cleanup(): + dist.destroy_process_group() + + setup(2, 2) + net = torch.nn.parallel.DistributedDataParallel( + net + ) + params = group_decay(net) + optimizer = setup_optimizer(params, config) + + if config.scheduler == "none": + scheduler = torch.optim.lr_scheduler.LambdaLR( + optimizer, lambda epoch: 1.0 + ) + + elif config.scheduler == "onecycle": + steps_per_epoch = len(train_loader) + pct_start = config.warmup_steps / (config.epochs * steps_per_epoch) + scheduler = torch.optim.lr_scheduler.OneCycleLR( + optimizer, + max_lr=config.learning_rate, + epochs=config.epochs, + steps_per_epoch=steps_per_epoch, + # pct_start=pct_start, + pct_start=0.3, + ) + elif config.scheduler == "step": + scheduler = torch.optim.lr_scheduler.StepLR( + optimizer, + step_size=100000, + gamma=0.96, + ) + elif config.scheduler == "polynomial": + steps_per_epoch = len(train_loader) + num_steps = config.epochs * steps_per_epoch + scheduler = PolynomialLR( + optimizer, + decay_steps=num_steps, + end_lr=1e-5, + ) + + # select configured loss function + criteria = { + "mse": nn.MSELoss(), + "l1": nn.L1Loss(), + } + criterion = criteria[config.criterion] + # set up training engine and evaluators + metrics = {"loss": Loss(criterion), "mae": MeanAbsoluteError() * std_train, "neg_mae": -1.0 * MeanAbsoluteError() * std_train} + trainer = create_supervised_trainer( + net, + optimizer, + criterion, + prepare_batch=prepare_batch, + device=device, + deterministic=deterministic, + ) + evaluator = create_supervised_evaluator( + net, + metrics=metrics, + prepare_batch=prepare_batch, + device=device, + ) + train_evaluator = create_supervised_evaluator( + net, + metrics=metrics, + prepare_batch=prepare_batch, + device=device, + ) + if test_only: + checkpoint_tmp = torch.load('/your_model_path.pt') + to_load = { + "model": net, + "optimizer": optimizer, + "lr_scheduler": scheduler, + "trainer": trainer, + } + Checkpoint.load_objects(to_load=to_load, checkpoint=checkpoint_tmp) + net.eval() + targets = [] + predictions = [] + import time + t1 = time.time() + with torch.no_grad(): + for dat in test_loader: + g, lg, _, target = dat + try: + out_data = net([g.to(device), lg.to(device), _.to(device)]) + success_flag=1 + except: # just in case + print('error for this data') + print(g) + success_flag=0 + if success_flag > 0: + out_data = out_data.cpu().numpy().tolist() + target = target.cpu().numpy().flatten().tolist() + if len(target) == 1: + target = target[0] + targets.append(target) + predictions.append(out_data) + t2 = time.time() + f.close() + from sklearn.metrics import mean_absolute_error + targets = np.array(targets) * std_train + predictions = np.array(predictions) * std_train + print("Test MAE:", mean_absolute_error(targets, predictions)) + print("Total test time:", t2-t1) + return mean_absolute_error(targets, predictions) + # ignite event handlers: + trainer.add_event_handler(Events.EPOCH_COMPLETED, TerminateOnNan()) + + # apply learning rate scheduler + trainer.add_event_handler( + Events.ITERATION_COMPLETED, lambda engine: scheduler.step() + ) + + # checkpoint_tmp = 
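In the onecycle branch above, pct_start is first derived from warmup_steps / (epochs * steps_per_epoch), but the OneCycleLR call then passes the fixed pct_start=0.3, so the warmup phase spans 30% of all iterations regardless of config.warmup_steps. A quick arithmetic sketch with illustrative numbers (500 epochs of 100 iterations, warmup_steps=2000):

epochs, steps_per_epoch, warmup_steps = 500, 100, 2000
total_steps = epochs * steps_per_epoch
pct_from_config = warmup_steps / total_steps   # 0.04 -> would warm up for 2000 steps
pct_hard_coded = 0.3                           # warms up for 15000 steps instead
print(pct_from_config, int(pct_hard_coded * total_steps))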
torch.load("/mnt/data/shared/congfu/CompCrystal/NewModel_27/matformer/scripts/matbench_mp_e_form_equivariant_max25_epoch500_lr1e-3_L1_fold1/checkpoint_299.pt") + # to_load = { + # "model": net, + # "optimizer": optimizer, + # "lr_scheduler": scheduler, + # "trainer": trainer, + # } + # Checkpoint.load_objects(to_load=to_load, checkpoint=checkpoint_tmp) + # print('checkpoint.pt loaded') + # print('current epoch:', trainer.state.epoch) + # print('current optimizer:', optimizer) + # print('current scheduler:', scheduler) + + if config.write_checkpoint: + # model checkpointing + to_save = { + "model": net, + "optimizer": optimizer, + "lr_scheduler": scheduler, + "trainer": trainer, + } + handler = Checkpoint( + to_save, + DiskSaver(checkpoint_dir, create_dir=True, require_empty=False), + n_saved=2, + global_step_transform=lambda *_: trainer.state.epoch, + ) + trainer.add_event_handler(Events.EPOCH_COMPLETED, handler) + # evaluate save + to_save = {"model": net} + handler = Checkpoint( + to_save, + DiskSaver(checkpoint_dir, create_dir=True, require_empty=False), + n_saved=5, + filename_prefix='best', + score_name="neg_mae", + global_step_transform=lambda *_: trainer.state.epoch, + ) + evaluator.add_event_handler(Events.EPOCH_COMPLETED, handler) + if config.progress: + pbar = ProgressBar() + pbar.attach(trainer, output_transform=lambda x: {"loss": x}) + # pbar.attach(evaluator,output_transform=lambda x: {"mae": x}) + + history = { + "train": {m: [] for m in metrics.keys()}, + "validation": {m: [] for m in metrics.keys()}, + } + + if config.store_outputs: + # in history["EOS"] + eos = EpochOutputStore() + eos.attach(evaluator) + train_eos = EpochOutputStore() + train_eos.attach(train_evaluator) + + # collect evaluation performance + @trainer.on(Events.EPOCH_COMPLETED) + def log_results(engine): + """Print training and validation metrics to console.""" + # train_evaluator.run(train_loader) + # evaluator.run(val_loader) + + # tmetrics = train_evaluator.state.metrics + # vmetrics = evaluator.state.metrics + # for metric in metrics.keys(): + # tm = tmetrics[metric] + # vm = vmetrics[metric] + # if metric == "roccurve": + # tm = [k.tolist() for k in tm] + # vm = [k.tolist() for k in vm] + # if isinstance(tm, torch.Tensor): + # tm = tm.cpu().numpy().tolist() + # vm = vm.cpu().numpy().tolist() + + # history["train"][metric].append(tm) + # history["validation"][metric].append(vm) + + # train_evaluator.run(train_loader) + evaluator.run(val_loader) + + vmetrics = evaluator.state.metrics + for metric in metrics.keys(): + vm = vmetrics[metric] + t_metric = metric + if metric == "roccurve": + vm = [k.tolist() for k in vm] + if isinstance(vm, torch.Tensor): + vm = vm.cpu().numpy().tolist() + + history["validation"][metric].append(vm) + + + + epoch_num = len(history["validation"][t_metric]) + if epoch_num % 20 == 0: + train_evaluator.run(train_loader) + tmetrics = train_evaluator.state.metrics + for metric in metrics.keys(): + tm = tmetrics[metric] + if metric == "roccurve": + tm = [k.tolist() for k in tm] + if isinstance(tm, torch.Tensor): + tm = tm.cpu().numpy().tolist() + + history["train"][metric].append(tm) + else: + tmetrics = {} + tmetrics['mae'] = -1 + + + # for metric in metrics.keys(): + # history["train"][metric].append(tmetrics[metric]) + # history["validation"][metric].append(vmetrics[metric]) + + if config.store_outputs: + history["EOS"] = eos.data + history["trainEOS"] = train_eos.data + dumpjson( + filename=os.path.join(config.output_dir, "history_val.json"), + data=history["validation"], + ) + 
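Because the targets are standardized with the training-set mean and std, the ignite metrics above multiply MeanAbsoluteError by std_train so that the logged MAE is already expressed in the original target units. A numpy sketch of that rescaling with made-up numbers:

import numpy as np

y_true = np.array([0.10, -0.35, 1.20])                    # made-up targets, original units
mean_train, std_train = 0.25, 0.80                        # made-up training statistics
y_true_std = (y_true - mean_train) / std_train            # what the network is fit against
y_pred_std = y_true_std + np.array([0.05, -0.10, 0.02])   # pretend model errors

mae_std = np.abs(y_pred_std - y_true_std).mean()          # MAE in standardized units
print(mae_std * std_train)                                # ~0.0453, back in original units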
dumpjson( + filename=os.path.join(config.output_dir, "history_train.json"), + data=history["train"], + ) + if config.progress: + pbar = ProgressBar() + if not classification: + pbar.log_message(f"Val_MAE: {vmetrics['mae']:.4f}") + pbar.log_message(f"Train_MAE: {tmetrics['mae']:.4f}") + else: + pbar.log_message(f"Train ROC AUC: {tmetrics['rocauc']:.4f}") + pbar.log_message(f"Val ROC AUC: {vmetrics['rocauc']:.4f}") + + if config.n_early_stopping is not None: + if classification: + my_metrics = "accuracy" + else: + my_metrics = "neg_mae" + + def default_score_fn(engine): + score = engine.state.metrics[my_metrics] + return score + + es_handler = EarlyStopping( + patience=config.n_early_stopping, + score_function=default_score_fn, + trainer=trainer, + ) + evaluator.add_event_handler(Events.EPOCH_COMPLETED, es_handler) + # optionally log results to tensorboard + if config.log_tensorboard: + + tb_logger = TensorboardLogger( + log_dir=os.path.join(config.output_dir, "tb_logs", "test") + ) + for tag, evaluator in [ + ("training", train_evaluator), + ("validation", evaluator), + ]: + tb_logger.attach_output_handler( + evaluator, + event_name=Events.EPOCH_COMPLETED, + tag=tag, + metric_names=["loss", "mae"], + global_step_transform=global_step_from_engine(trainer), + ) + + trainer.run(train_loader, max_epochs=config.epochs) + + if config.log_tensorboard: + test_loss = evaluator.state.metrics["loss"] + tb_logger.writer.add_hparams(config, {"hparam/test_loss": test_loss}) + tb_logger.close() + if config.write_predictions and classification: + net.eval() + f = open( + os.path.join(config.output_dir, "prediction_results_test_set.csv"), + "w", + ) + f.write("id,target,prediction\n") + targets = [] + predictions = [] + with torch.no_grad(): + ids = test_loader.dataset.ids # [test_loader.dataset.indices] + for dat, id in zip(test_loader, ids): + g, lg, target = dat + out_data = net([g.to(device), lg.to(device)]) + # out_data = torch.exp(out_data.cpu()) + top_p, top_class = torch.topk(torch.exp(out_data), k=1) + target = int(target.cpu().numpy().flatten().tolist()[0]) + + f.write("%s, %d, %d\n" % (id, (target), (top_class))) + targets.append(target) + predictions.append( + top_class.cpu().numpy().flatten().tolist()[0] + ) + f.close() + from sklearn.metrics import roc_auc_score + + print("predictions", predictions) + print("targets", targets) + print( + "Test ROCAUC:", + roc_auc_score(np.array(targets), np.array(predictions)), + ) + + if ( + config.write_predictions + and not classification + and config.model.output_features > 1 + ): + net.eval() + mem = [] + with torch.no_grad(): + ids = test_loader.dataset.ids # [test_loader.dataset.indices] + for dat, id in zip(test_loader, ids): + g, lg, target = dat + out_data = net([g.to(device), lg.to(device)]) + out_data = out_data.cpu().numpy().tolist() + if config.standard_scalar_and_pca: + sc = pk.load(open("sc.pkl", "rb")) + out_data = list( + sc.transform(np.array(out_data).reshape(1, -1))[0] + ) # [0][0] + target = target.cpu().numpy().flatten().tolist() + info = {} + info["id"] = id + info["target"] = target + info["predictions"] = out_data + mem.append(info) + dumpjson( + filename=os.path.join( + config.output_dir, "multi_out_predictions.json" + ), + data=mem, + ) + if ( + config.write_predictions + and not classification + and config.model.output_features == 1 + ): + net.eval() + f = open( + os.path.join(config.output_dir, "prediction_results_test_set.csv"), + "w", + ) + f.write("id,target,prediction\n") + targets = [] + predictions = [] + with 
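For classification targets, the prediction loop above exponentiates the log-probabilities, takes the top-1 class with torch.topk, and scores the resulting hard labels with sklearn's roc_auc_score. A tiny sketch of that conversion on made-up model outputs:

import torch
from sklearn.metrics import roc_auc_score

log_probs = torch.log(torch.tensor([[0.8, 0.2], [0.3, 0.7], [0.1, 0.9]]))
top_p, top_class = torch.topk(torch.exp(log_probs), k=1)
predictions = top_class.flatten().tolist()   # [0, 1, 1]
targets = [0, 1, 1]
print(roc_auc_score(targets, predictions))   # 1.0 on this toy batch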
torch.no_grad(): + for dat in test_loader: + g, lg, _, target = dat + out_data = net([g.to(device), lg.to(device), lg.to(device)]) + out_data = out_data.cpu().numpy().tolist() + target = target.cpu().numpy().flatten().tolist() + if len(target) == 1: + target = target[0] + targets.append(target) + predictions.append(out_data) + f.close() + from sklearn.metrics import mean_absolute_error + + print( + "Test MAE:", + mean_absolute_error(np.array(targets), np.array(predictions)) * std_train, + "STD train:", + std_train, + ) + if config.store_outputs and not classification: + x = [] + y = [] + for i in history["EOS"]: + x.append(i[0].cpu().numpy().tolist()) + y.append(i[1].cpu().numpy().tolist()) + x = np.array(x, dtype="float").flatten() + y = np.array(y, dtype="float").flatten() + f = open( + os.path.join( + config.output_dir, "prediction_results_train_set.csv" + ), + "w", + ) + # TODO: Add IDs + f.write("target,prediction\n") + for i, j in zip(x, y): + f.write("%6f, %6f\n" % (j, i)) + line = str(i) + "," + str(j) + "\n" + f.write(line) + f.close() + # return history + return np.array(targets) * std_train + mean_train, np.array(predictions) * std_train + mean_train + + diff --git a/benchmarks/matbench_v0.1_eComFormer/train.sh b/benchmarks/matbench_v0.1_eComFormer/train.sh new file mode 100644 index 00000000..4c605915 --- /dev/null +++ b/benchmarks/matbench_v0.1_eComFormer/train.sh @@ -0,0 +1,18 @@ +#!/bin/sh + +GPU=1 + +fold=0 + +CUDA_VISIBLE_DEVICES=${GPU} \ +python ./train.py \ +--output_dir="../matbench_mp_e_form_equivariant_max25_epoch500_lr1e-3_L1_fold"$fold \ +--max_neighbors=25 \ +--epochs=500 \ +--batch_size=64 \ +--task_name="matbench_mp_e_form" \ +--lr=1e-3 \ +--criterion='l1' \ +--fold_num=$fold \ +--multi_GPU \ +--model_variant="matformerequivariant" diff --git a/benchmarks/matbench_v0.1_eComFormer/train_props.py b/benchmarks/matbench_v0.1_eComFormer/train_props.py new file mode 100644 index 00000000..378fc5c0 --- /dev/null +++ b/benchmarks/matbench_v0.1_eComFormer/train_props.py @@ -0,0 +1,304 @@ +"""Helper function for high-throughput GNN trainings.""" +"""Implementation based on the template of ALIGNN.""" +import matplotlib.pyplot as plt + +# import numpy as np +import time +from matformer.train import train_dgl +import os +import numpy as np + +# from sklearn.metrics import mean_absolute_error +plt.switch_backend("agg") + + +def train_prop_model( + prop="", + dataset="dft_3d", + write_predictions=True, + name="pygatt", + save_dataloader=False, + train_ratio=None, + classification_threshold=None, + val_ratio=None, + test_ratio=None, + learning_rate=0.001, + batch_size=None, + scheduler=None, + n_epochs=None, + id_tag=None, + num_workers=None, + weight_decay=None, + edge_input_features=None, + triplet_input_features=None, + embedding_features=None, + hidden_features=None, + output_features=None, + random_seed=None, + n_early_stopping=None, + cutoff=None, + max_neighbors=None, + matrix_input=False, + pyg_input=False, + use_lattice=False, + use_angle=False, + output_dir=None, + neighbor_strategy="k-nearest", + test_only=False, + use_save=True, + mp_id_list=None, + file_name=None, + atom_features="cgcnn", + task_name=None, + save_dir=None, + criterion=None, + multi_GPU=False, + fold_num=None, + model_variant=None, +): + """Train models for a dataset and a property.""" + if scheduler is None: + scheduler = "onecycle" + # scheduler = "none" + if batch_size is None: + batch_size = 64 + if n_epochs is None: + n_epochs = 500 + if num_workers is None: + num_workers = 10 + config = { + 
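train.sh drives train.py with command-line flags, and the argparse glue that forwards them into train_prop_model is not part of this diff, so the following is only a hedged sketch of what such an entry point could look like. The flag names mirror train.sh and the keyword names mirror the train_prop_model signature above; the values for prop, dataset, val_ratio and test_ratio are placeholders, not taken from the benchmark.

import argparse
from train_props import train_prop_model   # assumed local import

p = argparse.ArgumentParser()
p.add_argument("--output_dir", required=True)
p.add_argument("--max_neighbors", type=int, default=25)
p.add_argument("--epochs", type=int, default=500)
p.add_argument("--batch_size", type=int, default=64)
p.add_argument("--task_name", default="matbench_mp_e_form")
p.add_argument("--lr", type=float, default=1e-3)
p.add_argument("--criterion", default="l1")
p.add_argument("--fold_num", type=int, default=0)
p.add_argument("--multi_GPU", action="store_true")
p.add_argument("--model_variant", default="matformerequivariant")
args = p.parse_args()

train_prop_model(
    prop="e_form",                # placeholder; the matbench path trains on the "label" column
    dataset="megnet",             # placeholder dataset tag
    learning_rate=args.lr,
    batch_size=args.batch_size,
    n_epochs=args.epochs,
    max_neighbors=args.max_neighbors,
    criterion=args.criterion,
    output_dir=args.output_dir,
    task_name=args.task_name,
    fold_num=args.fold_num,
    multi_GPU=args.multi_GPU,
    model_variant=args.model_variant,
    val_ratio=0.1,                # train_prop_model raises if these two are left unset
    test_ratio=0.1,
)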
"dataset": dataset, + "target": "label", #prop, + "epochs": n_epochs, # 00,#00, + "batch_size": batch_size, # 0, + "weight_decay": 1e-05, + "learning_rate": learning_rate, + "criterion": criterion, #'l1', #"mse", + "optimizer": "adamw", + "scheduler": scheduler, + "save_dataloader": save_dataloader, + "pin_memory": False, + "write_predictions": write_predictions, + "num_workers": num_workers, + "classification_threshold": classification_threshold, + "atom_features": atom_features, + "model": { + "name": name, + }, + } + if n_early_stopping is not None: + config["n_early_stopping"] = n_early_stopping + if cutoff is not None: + config["cutoff"] = cutoff + if max_neighbors is not None: + config["max_neighbors"] = max_neighbors + if weight_decay is not None: + config["weight_decay"] = weight_decay + if edge_input_features is not None: + config["model"]["edge_input_features"] = edge_input_features + if hidden_features is not None: + config["model"]["hidden_features"] = hidden_features + if embedding_features is not None: + config["model"]["embedding_features"] = embedding_features + if output_features is not None: + config["model"]["output_features"] = output_features + if random_seed is not None: + config["random_seed"] = random_seed + if file_name is not None: + config["filename"] = file_name + # if model_name is not None: + # config['model']['name']=model_name + config["matrix_input"] = matrix_input + config["pyg_input"] = pyg_input + config["use_lattice"] = use_lattice + config["use_angle"] = use_angle + config["model"]["use_angle"] = use_angle + config["neighbor_strategy"] = neighbor_strategy + # config["output_dir"] = '.' + if output_dir is not None: + config["output_dir"] = output_dir + + if id_tag is not None: + config["id_tag"] = id_tag + if train_ratio is not None: + config["train_ratio"] = train_ratio + if val_ratio is None: + raise ValueError("Enter val_ratio.") + + if test_ratio is None: + raise ValueError("Enter test_ratio.") + config["val_ratio"] = val_ratio + config["test_ratio"] = test_ratio + if dataset == "jv_3d": + # config["save_dataloader"]=True + config["num_workers"] = 4 + config["pin_memory"] = False + # config["learning_rate"] = 0.001 + # config["epochs"] = 300 + + if dataset == "mp_3d_2020": + config["id_tag"] = "id" + config["num_workers"] = 0 + if dataset == "megnet2": + config["id_tag"] = "id" + config["num_workers"] = 0 + if dataset == "megnet": + config["id_tag"] = "id" + if prop == "e_form" or prop == "gap pbe": + config["n_train"] = 60000 + config["n_val"] = 5000 + config["n_test"] = 4239 + # config["learning_rate"] = 0.01 + # config["epochs"] = 300 + config["num_workers"] = 8 + else: + config["n_train"] = 4664 + config["n_val"] = 393 + config["n_test"] = 393 + if dataset == "oqmd_3d_no_cfid": + config["id_tag"] = "_oqmd_entry_id" + config["num_workers"] = 0 + if dataset == "hmof" and prop == "co2_absp": + config["model"]["output_features"] = 5 + if dataset == "edos_pdos": + if prop == "edos_up": + config["model"]["output_features"] = 300 + elif prop == "pdos_elast": + config["model"]["output_features"] = 200 + else: + raise ValueError("Target not available.") + if dataset == "qm9_std_jctc": + config["id_tag"] = "id" + config["n_train"] = 110000 + config["n_val"] = 10000 + config["n_test"] = 10829 + + # config["batch_size"] = 64 + config["cutoff"] = 5.0 + config["standard_scalar_and_pca"] = False + + if dataset == "qm9_dgl": + config["id_tag"] = "id" + config["n_train"] = 110000 + config["n_val"] = 10000 + config["n_test"] = 10831 + 
config["standard_scalar_and_pca"] = False + config["batch_size"] = 64 + config["cutoff"] = 5.0 + if config["target"] == "all": + config["model"]["output_features"] = 12 + + # config["max_neighbors"] = 9 + + if dataset == "hpov": + config["id_tag"] = "id" + if dataset == "qm9": + config["id_tag"] = "id" + config["n_train"] = 110000 + config["n_val"] = 10000 + config["n_test"] = 13885 + config["batch_size"] = batch_size + config["cutoff"] = 5.0 + config["max_neighbors"] = 9 + # config['atom_features']='atomic_number' + if prop in ["homo", "lumo", "gap", "zpve", "U0", "U", "H", "G"]: + config["target_multiplication_factor"] = 27.211386024367243 + + if dataset == 'mpf': + config["id_tag"] = "id" + config["n_train"] = 169516 + config["n_val"] = 9417 + config["n_test"] = 9417 + + if test_only: + t1 = time.time() + result = train_dgl(config, test_only=test_only, use_save=use_save, mp_id_list=mp_id_list) + t2 = time.time() + print("test mae=", result) + print("Toal time:", t2 - t1) + print() + print() + print() + else: + # t1 = time.time() + # result = train_dgl(config, use_save=use_save, mp_id_list=mp_id_list) + # t2 = time.time() + # print("train=", result["train"]) + # print("validation=", result["validation"]) + # print("Toal time:", t2 - t1) + # print() + # print() + # print() + + from matbench.bench import MatbenchBenchmark + + mb = MatbenchBenchmark(subset=[task_name], autoload=False) + + print(f"Running task: {task_name} Fold: {fold_num}") + + if multi_GPU: + + task = next(iter(mb.tasks)) + assert (task.dataset_name == task_name) + task.load() + + train_inputs, train_outputs = task.get_train_and_val_data(fold_num) + test_inputs, test_outputs = task.get_test_data(fold_num, include_target=True) + + + # train_label = train_outputs.values + # print("=0: ", np.sum((train_label < 1e-13)) / train_label.shape[0]) + # print("< 1e-6: ", np.sum((train_label < 1e-6)) / train_label.shape[0]) + # print("1e-6 ~ 1e-2: ", np.sum(np.logical_and((train_label > 1e-6), (train_label < 1e-2))) / train_label.shape[0]) + # print("1e-2 ~ 1e-1: ", np.sum(np.logical_and((train_label > 1e-2), (train_label < 1e-1))) / train_label.shape[0]) + # print("1e-1 ~ 1: ", np.sum(np.logical_and((train_label > 1e-1), (train_label < 1))) / train_label.shape[0]) + # print("> 1: ", np.sum((train_label > 1)) / train_label.shape[0]) + + # import pdb; pdb.set_trace() + + target, predictions = train_dgl(config, + use_save=use_save, + mp_id_list=mp_id_list, + train_inputs=train_inputs, + train_outputs=train_outputs, + test_inputs=test_inputs, + test_outputs=test_outputs, + model_variant=model_variant) + + np.save(os.path.join(output_dir, f"result_fold_{fold_num}.npy"), predictions) + + else: + for task in mb.tasks: + task.load() + for fold in task.folds: + + train_inputs, train_outputs = task.get_train_and_val_data(fold) + test_inputs, test_outputs = task.get_test_data(fold, include_target=True) + + # import pdb; pdb.set_trace() + target, predictions = train_dgl(config, + use_save=use_save, + mp_id_list=mp_id_list, + train_inputs=train_inputs, + train_outputs=train_outputs, + test_inputs=test_inputs, + test_outputs=test_outputs, + model_variant=model_variant) + + # import pdb; pdb.set_trace() + # print("test set error", (target - test_outputs).sum()) + # import pdb; pdb.set_trace() + + # Predict on the testing data + # Your output should be a pandas series, numpy array, or python iterable + # where the array elements are floats or bools + # predictions = my_model.predict(test_inputs) + + # Record your data! 
+ task.record(fold, predictions) + + # Save your results + mb.to_file(os.path.join(output_dir, f"{task_name}.json.gz")) + + + diff --git a/benchmarks/matbench_v0.1_eComFormer/utils.py b/benchmarks/matbench_v0.1_eComFormer/utils.py new file mode 100644 index 00000000..c9cce671 --- /dev/null +++ b/benchmarks/matbench_v0.1_eComFormer/utils.py @@ -0,0 +1,45 @@ +"""Shared pydantic settings configuration.""" +"""Implementation based on the template of ALIGNN.""" +import json +from pathlib import Path +from typing import Union +import matplotlib.pyplot as plt + +from pydantic import BaseSettings as PydanticBaseSettings + + +class BaseSettings(PydanticBaseSettings): + """Add configuration to default Pydantic BaseSettings.""" + + class Config: + """Configure BaseSettings behavior.""" + + extra = "forbid" + use_enum_values = True + env_prefix = "jv_" + + +def plot_learning_curve( + results_dir: Union[str, Path], key: str = "mae", plot_train: bool = False +): + """Plot learning curves based on json history files.""" + if isinstance(results_dir, str): + results_dir = Path(results_dir) + + with open(results_dir / "history_val.json", "r") as f: + val = json.load(f) + + p = plt.plot(val[key], label=results_dir.name) + + if plot_train: + # plot the training trace in the same color, lower opacity + with open(results_dir / "history_train.json", "r") as f: + train = json.load(f) + + c = p[0].get_color() + plt.plot(train[key], alpha=0.5, c=c) + + plt.xlabel("epochs") + plt.ylabel(key) + + return train, val diff --git a/benchmarks/matbench_v0.1_iComFormer/config.py b/benchmarks/matbench_v0.1_iComFormer/config.py new file mode 100644 index 00000000..0b191a30 --- /dev/null +++ b/benchmarks/matbench_v0.1_iComFormer/config.py @@ -0,0 +1,195 @@ +"""Pydantic model for default configuration and validation.""" +"""Implementation based on the template of ALIGNN.""" + +import subprocess +from typing import Optional, Union +import os +from pydantic import root_validator + +# vfrom pydantic import Field, root_validator, validator +from pydantic.typing import Literal +from matformer.utils import BaseSettings +from matformer.models.pyg_att import MatformerConfig + +# from typing import List + +try: + VERSION = ( + subprocess.check_output(["git", "rev-parse", "HEAD"]).decode().strip() + ) +except Exception as exp: + VERSION = "NA" + pass + + +FEATURESET_SIZE = {"basic": 11, "atomic_number": 1, "cfid": 438, "cgcnn": 92} + + +TARGET_ENUM = Literal[ + "formation_energy_peratom", + "optb88vdw_bandgap", + "bulk_modulus_kv", + "shear_modulus_gv", + "mbj_bandgap", + "slme", + "magmom_oszicar", + "spillage", + "kpoint_length_unit", + "encut", + "optb88vdw_total_energy", + "epsx", + "epsy", + "epsz", + "mepsx", + "mepsy", + "mepsz", + "max_ir_mode", + "min_ir_mode", + "n-Seebeck", + "p-Seebeck", + "n-powerfact", + "p-powerfact", + "ncond", + "pcond", + "nkappa", + "pkappa", + "ehull", + "exfoliation_energy", + "dfpt_piezo_max_dielectric", + "dfpt_piezo_max_eij", + "dfpt_piezo_max_dij", + "gap pbe", + "e_form", + "e_hull", + "energy_per_atom", + "formation_energy_per_atom", + "band_gap", + "e_above_hull", + "mu_b", + "bulk modulus", + "shear modulus", + "elastic anisotropy", + "U0", + "HOMO", + "LUMO", + "R2", + "ZPVE", + "omega1", + "mu", + "alpha", + "homo", + "lumo", + "gap", + "r2", + "zpve", + "U", + "H", + "G", + "Cv", + "A", + "B", + "C", + "all", + "target", + "max_efg", + "avg_elec_mass", + "avg_hole_mass", + "_oqmd_band_gap", + "_oqmd_delta_e", + "_oqmd_stability", + "edos_up", + "pdos_elast", + "bandgap", + "energy_total", 
+ "net_magmom", + "b3lyp_homo", + "b3lyp_lumo", + "b3lyp_gap", + "b3lyp_scharber_pce", + "b3lyp_scharber_voc", + "b3lyp_scharber_jsc", + "log_kd_ki", + "max_co2_adsp", + "min_co2_adsp", + "lcd", + "pld", + "void_fraction", + "surface_area_m2g", + "surface_area_m2cm3", + "indir_gap", + "f_enp", + "final_energy", + "energy_per_atom", + "label", +] + + +class TrainingConfig(BaseSettings): + """Training config defaults and validation.""" + + version: str = VERSION + + # dataset configuration + dataset: Literal[ + "dft_3d", + "megnet", + "mpf", + ] = "dft_3d" + target: TARGET_ENUM = "formation_energy_peratom" + atom_features: Literal["basic", "atomic_number", "cfid", "cgcnn"] = "cgcnn" + neighbor_strategy: Literal["k-nearest", "voronoi", "pairwise-k-nearest"] = "k-nearest" + id_tag: Literal["jid", "id", "_oqmd_entry_id"] = "jid" + + # logging configuration + + # training configuration + random_seed: Optional[int] = 123 + classification_threshold: Optional[float] = None + n_val: Optional[int] = None + n_test: Optional[int] = None + n_train: Optional[int] = None + train_ratio: Optional[float] = 0.8 + val_ratio: Optional[float] = 0.1 + test_ratio: Optional[float] = 0.1 + target_multiplication_factor: Optional[float] = None + epochs: int = 300 + batch_size: int = 64 + weight_decay: float = 0 + learning_rate: float = 1e-2 + filename: str = "sample" + warmup_steps: int = 2000 + criterion: Literal["mse", "l1", "poisson", "zig"] = "mse" + optimizer: Literal["adamw", "sgd"] = "adamw" + scheduler: Literal["onecycle", "none", "step", "polynomial"] = "onecycle" + pin_memory: bool = False + save_dataloader: bool = False + write_checkpoint: bool = True + write_predictions: bool = True + store_outputs: bool = True + progress: bool = True + log_tensorboard: bool = False + standard_scalar_and_pca: bool = False + use_canonize: bool = True + num_workers: int = 2 + cutoff: float = 4.0 + max_neighbors: int = 12 + keep_data_order: bool = False + distributed: bool = False + n_early_stopping: Optional[int] = None # typically 50 + output_dir: str = os.path.abspath(".") # typically 50 + matrix_input: bool = False + pyg_input: bool = False + use_lattice: bool = False + use_angle: bool = False + + # model configuration + model = MatformerConfig(name="matformer") + print(model) + @root_validator() + def set_input_size(cls, values): + """Automatically configure node feature dimensionality.""" + values["model"].atom_input_features = FEATURESET_SIZE[ + values["atom_features"] + ] + + return values diff --git a/benchmarks/matbench_v0.1_iComFormer/data.py b/benchmarks/matbench_v0.1_iComFormer/data.py new file mode 100644 index 00000000..3e386584 --- /dev/null +++ b/benchmarks/matbench_v0.1_iComFormer/data.py @@ -0,0 +1,664 @@ +"""Implementation based on the template of ALIGNN.""" + +import imp +import random +from pathlib import Path +from typing import Optional + +# from typing import Dict, List, Optional, Set, Tuple + +import os +import torch +import numpy as np +import pandas as pd +from jarvis.core.atoms import Atoms +from matformer.graphs import PygGraph, PygStructureDataset +# +from pymatgen.io.jarvis import JarvisAtomsAdaptor +from jarvis.db.figshare import data as jdata +from torch.utils.data import DataLoader +from tqdm import tqdm +import math +from jarvis.db.jsonutils import dumpjson +from pandarallel import pandarallel +pandarallel.initialize(progress_bar=True) +# from sklearn.pipeline import Pipeline +import pickle as pk + +from sklearn.preprocessing import StandardScaler + +# use pandas progress_apply 
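Graph construction is the slow step, so data.py initializes pandarallel and builds graphs with Series.parallel_apply rather than a plain apply. A minimal sketch of that pattern on a toy DataFrame, where a cheap squaring function stands in for the much more expensive atoms_to_graph:

import pandas as pd
from pandarallel import pandarallel

pandarallel.initialize(progress_bar=False)   # the benchmark code turns the progress bar on

def square(v):
    # stand-in for atoms_to_graph, applied to each row in parallel worker processes
    return v * v

df = pd.DataFrame({"x": [1, 2, 3, 4]})
print(df["x"].parallel_apply(square).tolist())   # [1, 4, 9, 16]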
+tqdm.pandas() + + +def load_dataset( + name: str = "dft_3d", + target=None, + limit: Optional[int] = None, + classification_threshold: Optional[float] = None, +): + """Load jarvis data.""" + d = jdata(name) + data = [] + for i in d: + if i[target] != "na" and not math.isnan(i[target]): + if classification_threshold is not None: + if i[target] <= classification_threshold: + i[target] = 0 + elif i[target] > classification_threshold: + i[target] = 1 + else: + raise ValueError( + "Check classification data type.", + i[target], + type(i[target]), + ) + data.append(i) + d = data + if limit is not None: + d = d[:limit] + d = pd.DataFrame(d) + return d + + +def mean_absolute_deviation(data, axis=None): + """Get Mean absolute deviation.""" + return np.mean(np.absolute(data - np.mean(data, axis)), axis) + + + +def load_pyg_graphs( + df: pd.DataFrame, + name: str = "dft_3d", + neighbor_strategy: str = "k-nearest", + cutoff: float = 8, + max_neighbors: int = 12, + cachedir: Optional[Path] = None, + use_canonize: bool = False, + use_lattice: bool = False, + use_angle: bool = False, +): + """Construct crystal graphs. + + Load only atomic number node features + and bond displacement vector edge features. + + Resulting graphs have scheme e.g. + ``` + Graph(num_nodes=12, num_edges=156, + ndata_schemes={'atom_features': Scheme(shape=(1,)} + edata_schemes={'r': Scheme(shape=(3,)}) + ``` + """ + + def atoms_to_graph(atoms): + """Convert structure dict to DGLGraph.""" + adaptor = JarvisAtomsAdaptor() + structure = adaptor.get_atoms(atoms) + return PygGraph.atom_dgl_multigraph( + structure, + neighbor_strategy=neighbor_strategy, + cutoff=cutoff, + atom_features="atomic_number", + max_neighbors=max_neighbors, + compute_line_graph=False, + use_canonize=use_canonize, + use_lattice=use_lattice, + use_angle=use_angle, + ) + + graphs = df["atoms"].parallel_apply(atoms_to_graph).values + # graphs = df["atoms"].apply(atoms_to_graph).values + + return graphs + + +def get_id_train_val_test( + total_size=1000, + split_seed=123, + train_ratio=None, + val_ratio=0.1, + test_ratio=0.1, + n_train=None, + n_test=None, + n_val=None, + keep_data_order=False, +): + """Get train, val, test IDs.""" + if ( + train_ratio is None + and val_ratio is not None + and test_ratio is not None + ): + if train_ratio is None: + assert val_ratio + test_ratio < 1 + train_ratio = 1 - val_ratio - test_ratio + print("Using rest of the dataset except the test and val sets.") + else: + assert train_ratio + val_ratio + test_ratio <= 1 + # indices = list(range(total_size)) + if n_train is None: + n_train = int(train_ratio * total_size) + if n_test is None: + n_test = int(test_ratio * total_size) + if n_val is None: + n_val = int(val_ratio * total_size) + ids = list(np.arange(total_size)) + if not keep_data_order: + random.seed(split_seed) + random.shuffle(ids) + if n_train + n_val + n_test > total_size: + raise ValueError( + "Check total number of samples.", + n_train + n_val + n_test, + ">", + total_size, + ) + + id_train = ids[:n_train] + id_val = ids[-(n_val + n_test) : -n_test] # noqa:E203 + id_test = ids[-n_test:] + return id_train, id_val, id_test + + +def get_torch_dataset( + dataset=[], + id_tag="jid", + target="", + neighbor_strategy="", + atom_features="", + use_canonize="", + name="", + line_graph="", + cutoff=8.0, + max_neighbors=12, + classification=False, + output_dir=".", + tmp_name="dataset", +): + """Get Torch Dataset.""" + df = pd.DataFrame(dataset) + # print("df", df) + vals = df[target].values + if target == "shear modulus" or target 
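get_id_train_val_test above shuffles the sample indices with split_seed (unless keep_data_order is set) and then slices the train, validation and test blocks off the shuffled list. A toy reproduction of that slicing for 10 samples with an 80/10/10 split:

import random

ids = list(range(10))
random.seed(123)
random.shuffle(ids)
n_train, n_val, n_test = 8, 1, 1
id_train = ids[:n_train]
id_val = ids[-(n_val + n_test):-n_test]
id_test = ids[-n_test:]
print(id_train, id_val, id_test)   # val and test are the last two shuffled indices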
== "bulk modulus": + val_list = [vals[i].item() for i in range(len(vals))] + vals = val_list + print("data range", np.max(vals), np.min(vals)) + print("data mean and std", np.mean(vals), np.std(vals)) + f = open(os.path.join(output_dir, tmp_name + "_data_range"), "w") + line = "Max=" + str(np.max(vals)) + "\n" + f.write(line) + line = "Min=" + str(np.min(vals)) + "\n" + f.write(line) + f.close() + + graphs = load_graphs( + df, + name=name, + neighbor_strategy=neighbor_strategy, + use_canonize=use_canonize, + cutoff=cutoff, + max_neighbors=max_neighbors, + ) + + data = StructureDataset( + df, + graphs, + target=target, + atom_features=atom_features, + line_graph=line_graph, + id_tag=id_tag, + classification=classification, + ) + return data + +def get_pyg_dataset( + dataset=[], + id_tag="jid", + target="", + neighbor_strategy="", + atom_features="", + use_canonize="", + name="", + line_graph="", + cutoff=8.0, + max_neighbors=12, + classification=False, + output_dir=".", + tmp_name="dataset", + use_lattice=False, + use_angle=False, + data_from='Jarvis', + use_save=False, + mean_train=None, + std_train=None, + now=False, # for test +): + """Get pyg Dataset.""" + df = pd.DataFrame(dataset) + # print("df", df) + # neighbor_strategy = "pairwise-k-nearest" + vals = df[target].values + if target == "shear modulus" or target == "bulk modulus": + val_list = [vals[i].item() for i in range(len(vals))] + vals = val_list + output_dir = "./saved_data/" + tmp_name + "test_graph_angle.pkl" # for fast test use + print("data range", np.max(vals), np.min(vals)) + print(output_dir) + print('graphs not saved') + graphs = load_pyg_graphs( + df, + name=name, + neighbor_strategy=neighbor_strategy, + use_canonize=use_canonize, + cutoff=cutoff, + max_neighbors=max_neighbors, + use_lattice=use_lattice, + use_angle=use_angle, + ) + if mean_train == None: + mean_train = np.mean(vals) + std_train = np.std(vals) + data = PygStructureDataset( + df, + graphs, + target=target, + atom_features=atom_features, + line_graph=line_graph, + id_tag=id_tag, + classification=classification, + neighbor_strategy=neighbor_strategy, + mean_train=mean_train, + std_train=std_train, + ) + else: + data = PygStructureDataset( + df, + graphs, + target=target, + atom_features=atom_features, + line_graph=line_graph, + id_tag=id_tag, + classification=classification, + neighbor_strategy=neighbor_strategy, + mean_train=mean_train, + std_train=std_train, + ) + return data, mean_train, std_train + + +def get_train_val_loaders( + dataset: str = "dft_3d", + dataset_array=[], + target: str = "formation_energy_peratom", + atom_features: str = "cgcnn", + neighbor_strategy: str = "k-nearest", + n_train=None, + n_val=None, + n_test=None, + train_ratio=None, + val_ratio=0.1, + test_ratio=0.1, + batch_size: int = 5, + standardize: bool = False, + line_graph: bool = True, + split_seed: int = 123, + workers: int = 0, + pin_memory: bool = True, + save_dataloader: bool = False, + filename: str = "sample", + id_tag: str = "jid", + use_canonize: bool = False, + cutoff: float = 8.0, + max_neighbors: int = 12, + classification_threshold: Optional[float] = None, + target_multiplication_factor: Optional[float] = None, + standard_scalar_and_pca=False, + keep_data_order=False, + output_features=1, + output_dir=None, + matrix_input=False, + pyg_input=False, + use_lattice=False, + use_angle=False, + use_save=True, + mp_id_list=None, + train_inputs=None, + train_outputs=None, + test_inputs=None, + test_outputs=None, +): + """Help function to set up JARVIS train and val 
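get_pyg_dataset computes mean_train and std_train from the training targets on the first call, and the validation and test datasets are later built with those same statistics passed in, so every split is presumably standardized against the training distribution inside PygStructureDataset. A numpy sketch of that convention with made-up targets:

import numpy as np

train_y = np.array([0.20, -0.10, 0.70, 0.40])
test_y = np.array([0.30, -0.20])
mean_train, std_train = train_y.mean(), train_y.std()
train_scaled = (train_y - mean_train) / std_train
test_scaled = (test_y - mean_train) / std_train   # training statistics, not test statistics
print(round(mean_train, 3), round(std_train, 3))
print(test_scaled)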
dataloaders.""" + # data loading + mean_train=None + std_train=None + assert (matrix_input and pyg_input) == False + + # train_sample = filename + "_train.data" + # val_sample = filename + "_val.data" + # test_sample = filename + "_test.data" + if not os.path.exists(output_dir): + os.makedirs(output_dir) + # if ( + # os.path.exists(train_sample) + # and os.path.exists(val_sample) + # and os.path.exists(test_sample) + # and save_dataloader + # ): + # print("Loading from saved file...") + # print("Make sure all the DataLoader params are same.") + # print("This module is made for debugging only.") + # train_loader = torch.load(train_sample) + # val_loader = torch.load(val_sample) + # test_loader = torch.load(test_sample) + # if train_loader.pin_memory != pin_memory: + # train_loader.pin_memory = pin_memory + # if test_loader.pin_memory != pin_memory: + # test_loader.pin_memory = pin_memory + # if val_loader.pin_memory != pin_memory: + # val_loader.pin_memory = pin_memory + # if train_loader.num_workers != workers: + # train_loader.num_workers = workers + # if test_loader.num_workers != workers: + # test_loader.num_workers = workers + # if val_loader.num_workers != workers: + # val_loader.num_workers = workers + # print("train", len(train_loader.dataset)) + # print("val", len(val_loader.dataset)) + # print("test", len(test_loader.dataset)) + # return ( + # train_loader, + # val_loader, + # test_loader, + # train_loader.dataset.prepare_batch, + # ) + # else: + # if not dataset_array: + # d = jdata(dataset) + # else: + # d = dataset_array + # # for ii, i in enumerate(pc_y): + # # d[ii][target] = pc_y[ii].tolist() + + # dat = [] + # if classification_threshold is not None: + # print( + # "Using ", + # classification_threshold, + # " for classifying ", + # target, + # " data.", + # ) + # print("Converting target data into 1 and 0.") + # all_targets = [] + + # # TODO:make an all key in qm9_dgl + # if dataset == "qm9_dgl" and target == "all": + # print("Making all qm9_dgl") + # tmp = [] + # for ii in d: + # ii["all"] = [ + # ii["mu"], + # ii["alpha"], + # ii["homo"], + # ii["lumo"], + # ii["gap"], + # ii["r2"], + # ii["zpve"], + # ii["U0"], + # ii["U"], + # ii["H"], + # ii["G"], + # ii["Cv"], + # ] + # tmp.append(ii) + # print("Made all qm9_dgl") + # d = tmp + # for i in d: + # if isinstance(i[target], list): # multioutput target + # all_targets.append(torch.tensor(i[target])) + # dat.append(i) + + # elif ( + # i[target] is not None + # and i[target] != "na" + # and not math.isnan(i[target]) + # ): + # if target_multiplication_factor is not None: + # i[target] = i[target] * target_multiplication_factor + # if classification_threshold is not None: + # if i[target] <= classification_threshold: + # i[target] = 0 + # elif i[target] > classification_threshold: + # i[target] = 1 + # else: + # raise ValueError( + # "Check classification data type.", + # i[target], + # type(i[target]), + # ) + # dat.append(i) + # all_targets.append(i[target]) + + + # if mp_id_list is not None: + # if mp_id_list == 'bulk': + # print('using mp bulk dataset') + # with open('/data/keqiangyan/bulk_shear/bulk_megnet_train.pkl', 'rb') as f: + # dataset_train = pk.load(f) + # with open('/data/keqiangyan/bulk_shear/bulk_megnet_val.pkl', 'rb') as f: + # dataset_val = pk.load(f) + # with open('/data/keqiangyan/bulk_shear/bulk_megnet_test.pkl', 'rb') as f: + # dataset_test = pk.load(f) + + # if mp_id_list == 'shear': + # print('using mp shear dataset') + # with open('/data/keqiangyan/bulk_shear/shear_megnet_train.pkl', 'rb') as f: + # 
dataset_train = pk.load(f) + # with open('/data/keqiangyan/bulk_shear/shear_megnet_val.pkl', 'rb') as f: + # dataset_val = pk.load(f) + # with open('/data/keqiangyan/bulk_shear/shear_megnet_test.pkl', 'rb') as f: + # dataset_test = pk.load(f) + + # else: + # id_train, id_val, id_test = get_id_train_val_test( + # total_size=len(dat), + # split_seed=split_seed, + # train_ratio=train_ratio, + # val_ratio=val_ratio, + # test_ratio=test_ratio, + # n_train=n_train, + # n_test=n_test, + # n_val=n_val, + # keep_data_order=keep_data_order, + # ) + # ids_train_val_test = {} + # ids_train_val_test["id_train"] = [dat[i][id_tag] for i in id_train] + # ids_train_val_test["id_val"] = [dat[i][id_tag] for i in id_val] + # ids_train_val_test["id_test"] = [dat[i][id_tag] for i in id_test] + # dumpjson( + # data=ids_train_val_test, + # filename=os.path.join(output_dir, "ids_train_val_test.json"), + # ) + # dataset_train = [dat[x] for x in id_train] + # dataset_val = [dat[x] for x in id_val] + # dataset_test = [dat[x] for x in id_test] + + dataset_train = [] + dataset_val = [] + dataset_test = [] + for i in range(len(train_inputs)): + dataset_train.append({"atoms":train_inputs[i], "label": train_outputs[i]}) + + for i in range(len(test_inputs)): + dataset_val.append({"atoms":test_inputs[i], "label": test_outputs[i]}) + dataset_test.append({"atoms":test_inputs[i], "label": test_outputs[i]}) + + print("Number of train data: ", len(dataset_train)) + print("Number of test data: ", len(dataset_test)) + + # import pdb; pdb.set_trace() + + # if standard_scalar_and_pca: + # y_data = [i[target] for i in dataset_train] + # # pipe = Pipeline([('scale', StandardScaler())]) + # if not isinstance(y_data[0], list): + # print("Running StandardScalar") + # y_data = np.array(y_data).reshape(-1, 1) + # sc = StandardScaler() + + # sc.fit(y_data) + # print("Mean", sc.mean_) + # print("Variance", sc.var_) + # try: + # print("New max", max(y_data)) + # print("New min", min(y_data)) + # except Exception as exp: + # print(exp) + # pass + + # pk.dump(sc, open(os.path.join(output_dir, "sc.pkl"), "wb")) + + # if classification_threshold is None: + # try: + # from sklearn.metrics import mean_absolute_error + + # print("MAX val:", max(all_targets)) + # print("MIN val:", min(all_targets)) + # print("MAD:", mean_absolute_deviation(all_targets)) + # try: + # f = open(os.path.join(output_dir, "mad"), "w") + # line = "MAX val:" + str(max(all_targets)) + "\n" + # line += "MIN val:" + str(min(all_targets)) + "\n" + # line += ( + # "MAD val:" + # + str(mean_absolute_deviation(all_targets)) + # + "\n" + # ) + # f.write(line) + # f.close() + # except Exception as exp: + # print("Cannot write mad", exp) + # pass + # # Random model precited value + # x_bar = np.mean(np.array([i[target] for i in dataset_train])) + # baseline_mae = mean_absolute_error( + # np.array([i[target] for i in dataset_test]), + # np.array([x_bar for i in dataset_test]), + # ) + # print("Baseline MAE:", baseline_mae) + # except Exception as exp: + # print("Data error", exp) + # pass + + train_data, mean_train, std_train = get_pyg_dataset( + dataset=dataset_train, + id_tag=id_tag, + atom_features=atom_features, + target=target, + neighbor_strategy=neighbor_strategy, + use_canonize=use_canonize, + name=dataset, + line_graph=line_graph, + cutoff=cutoff, + max_neighbors=max_neighbors, + classification=classification_threshold is not None, + output_dir=output_dir, + tmp_name="train_data", + use_lattice=use_lattice, + use_angle=use_angle, + use_save=False, + ) + # import pdb; 
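In the matbench path, the "atoms" entries handed to get_pyg_dataset above are pymatgen Structure objects from task.get_train_and_val_data, and load_pyg_graphs converts each one to a JARVIS Atoms object with JarvisAtomsAdaptor before building the PyG graph. A minimal sketch of that conversion on a made-up silicon cell:

from pymatgen.core import Lattice, Structure
from pymatgen.io.jarvis import JarvisAtomsAdaptor

# toy cubic Si cell; matbench supplies Structure objects of this kind
si = Structure(Lattice.cubic(5.43), ["Si", "Si"], [[0, 0, 0], [0.25, 0.25, 0.25]])
jarvis_atoms = JarvisAtomsAdaptor().get_atoms(si)   # pymatgen Structure -> jarvis Atoms
print(jarvis_atoms.num_atoms)                       # 2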
pdb.set_trace() + val_data,_,_ = get_pyg_dataset( + dataset=dataset_val, + id_tag=id_tag, + atom_features=atom_features, + target=target, + neighbor_strategy=neighbor_strategy, + use_canonize=use_canonize, + name=dataset, + line_graph=line_graph, + cutoff=cutoff, + max_neighbors=max_neighbors, + classification=classification_threshold is not None, + output_dir=output_dir, + tmp_name="val_data", + use_lattice=use_lattice, + use_angle=use_angle, + use_save=False, + mean_train=mean_train, + std_train=std_train, + ) + test_data,_,_ = get_pyg_dataset( + dataset=dataset_test, + id_tag=id_tag, + atom_features=atom_features, + target=target, + neighbor_strategy=neighbor_strategy, + use_canonize=use_canonize, + name=dataset, + line_graph=line_graph, + cutoff=cutoff, + max_neighbors=max_neighbors, + classification=classification_threshold is not None, + output_dir=output_dir, + tmp_name="test_data", + use_lattice=use_lattice, + use_angle=use_angle, + use_save=False, + mean_train=mean_train, + std_train=std_train, + ) + + collate_fn = train_data.collate + if line_graph: + collate_fn = train_data.collate_line_graph + + # use a regular pytorch dataloader + train_loader = DataLoader( + train_data, + batch_size=batch_size, + shuffle=True, + collate_fn=collate_fn, + drop_last=True, + num_workers=workers, + pin_memory=pin_memory, + ) + + val_loader = DataLoader( + val_data, + batch_size=batch_size, + shuffle=False, + collate_fn=collate_fn, + drop_last=True, + num_workers=workers, + pin_memory=pin_memory, + ) + + test_loader = DataLoader( + test_data, + batch_size=1, + shuffle=False, + collate_fn=collate_fn, + drop_last=False, + num_workers=workers, + pin_memory=pin_memory, + ) + if save_dataloader: + torch.save(train_loader, train_sample) + torch.save(val_loader, val_sample) + torch.save(test_loader, test_sample) + + print("n_train:", len(train_loader.dataset)) + print("n_val:", len(val_loader.dataset)) + print("n_test:", len(test_loader.dataset)) + return ( + train_loader, + val_loader, + test_loader, + train_loader.dataset.prepare_batch, + mean_train, + std_train, + ) + diff --git a/benchmarks/matbench_v0.1_iComFormer/features.py b/benchmarks/matbench_v0.1_iComFormer/features.py new file mode 100644 index 00000000..8e4edcb7 --- /dev/null +++ b/benchmarks/matbench_v0.1_iComFormer/features.py @@ -0,0 +1,265 @@ +# Based on the code from: https://github.com/klicperajo/dimenet, +# https://github.com/rusty1s/pytorch_geometric/blob/master/torch_geometric/nn/models/dimenet_utils.py + + +import numpy as np +from scipy.optimize import brentq +from scipy import special as sp +import torch +from math import pi as PI + +import sympy as sym + +device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') + +def Jn(r, n): + return np.sqrt(np.pi / (2 * r)) * sp.jv(n + 0.5, r) + + +def Jn_zeros(n, k): + zerosj = np.zeros((n, k), dtype='float32') + zerosj[0] = np.arange(1, k + 1) * np.pi + points = np.arange(1, k + n) * np.pi + racines = np.zeros(k + n - 1, dtype='float32') + for i in range(1, n): + for j in range(k + n - 1 - i): + foo = brentq(Jn, points[j], points[j + 1], (i, )) + racines[j] = foo + points = racines + zerosj[i][:k] = racines[:k] + + return zerosj + + +def spherical_bessel_formulas(n): + x = sym.symbols('x') + + f = [sym.sin(x) / x] + a = sym.sin(x) / x + for i in range(1, n): + b = sym.diff(a, x) / x + f += [sym.simplify(b * (-x)**i)] + a = sym.simplify(b) + return f + + +def bessel_basis(n, k): + zeros = Jn_zeros(n, k) + normalizer = [] + for order in range(n): + normalizer_tmp = [] + for i in 
range(k): + normalizer_tmp += [0.5 * Jn(zeros[order, i], order + 1)**2] + normalizer_tmp = 1 / np.array(normalizer_tmp)**0.5 + normalizer += [normalizer_tmp] + + f = spherical_bessel_formulas(n) + x = sym.symbols('x') + bess_basis = [] + for order in range(n): + bess_basis_tmp = [] + for i in range(k): + bess_basis_tmp += [ + sym.simplify(normalizer[order][i] * + f[order].subs(x, zeros[order, i] * x)) + ] + bess_basis += [bess_basis_tmp] + return bess_basis + + +def sph_harm_prefactor(k, m): + return ((2 * k + 1) * np.math.factorial(k - abs(m)) / + (4 * np.pi * np.math.factorial(k + abs(m))))**0.5 + + +def associated_legendre_polynomials(k, zero_m_only=True): + z = sym.symbols('z') + P_l_m = [[0] * (j + 1) for j in range(k)] + + P_l_m[0][0] = 1 + if k > 0: + P_l_m[1][0] = z + + for j in range(2, k): + P_l_m[j][0] = sym.simplify(((2 * j - 1) * z * P_l_m[j - 1][0] - + (j - 1) * P_l_m[j - 2][0]) / j) + if not zero_m_only: + for i in range(1, k): + P_l_m[i][i] = sym.simplify((1 - 2 * i) * P_l_m[i - 1][i - 1]) + if i + 1 < k: + P_l_m[i + 1][i] = sym.simplify( + (2 * i + 1) * z * P_l_m[i][i]) + for j in range(i + 2, k): + P_l_m[j][i] = sym.simplify( + ((2 * j - 1) * z * P_l_m[j - 1][i] - + (i + j - 1) * P_l_m[j - 2][i]) / (j - i)) + + return P_l_m + + +def real_sph_harm(l, zero_m_only=False, spherical_coordinates=True): + """ + Computes formula strings of the the real part of the spherical harmonics up to order l (excluded). + Variables are either cartesian coordinates x,y,z on the unit sphere or spherical coordinates phi and theta. + """ + if not zero_m_only: + x = sym.symbols('x') + y = sym.symbols('y') + S_m = [x*0] + C_m = [1+0*x] + # S_m = [0] + # C_m = [1] + for i in range(1, l): + x = sym.symbols('x') + y = sym.symbols('y') + S_m += [x*S_m[i-1] + y*C_m[i-1]] + C_m += [x*C_m[i-1] - y*S_m[i-1]] + + P_l_m = associated_legendre_polynomials(l, zero_m_only) + if spherical_coordinates: + theta = sym.symbols('theta') + z = sym.symbols('z') + for i in range(len(P_l_m)): + for j in range(len(P_l_m[i])): + if type(P_l_m[i][j]) != int: + P_l_m[i][j] = P_l_m[i][j].subs(z, sym.cos(theta)) + if not zero_m_only: + phi = sym.symbols('phi') + for i in range(len(S_m)): + S_m[i] = S_m[i].subs(x, sym.sin( + theta)*sym.cos(phi)).subs(y, sym.sin(theta)*sym.sin(phi)) + for i in range(len(C_m)): + C_m[i] = C_m[i].subs(x, sym.sin( + theta)*sym.cos(phi)).subs(y, sym.sin(theta)*sym.sin(phi)) + + Y_func_l_m = [['0']*(2*j + 1) for j in range(l)] + for i in range(l): + Y_func_l_m[i][0] = sym.simplify(sph_harm_prefactor(i, 0) * P_l_m[i][0]) + + if not zero_m_only: + for i in range(1, l): + for j in range(1, i + 1): + Y_func_l_m[i][j] = sym.simplify( + 2**0.5 * sph_harm_prefactor(i, j) * C_m[j] * P_l_m[i][j]) + for i in range(1, l): + for j in range(1, i + 1): + Y_func_l_m[i][-j] = sym.simplify( + 2**0.5 * sph_harm_prefactor(i, -j) * S_m[j] * P_l_m[i][j]) + + return Y_func_l_m + + +class Envelope(torch.nn.Module): + def __init__(self, exponent): + super(Envelope, self).__init__() + self.p = exponent + 1 + self.a = -(self.p + 1) * (self.p + 2) / 2 + self.b = self.p * (self.p + 2) + self.c = -self.p * (self.p + 1) / 2 + + def forward(self, x): + p, a, b, c = self.p, self.a, self.b, self.c + x_pow_p0 = x.pow(p - 1) + x_pow_p1 = x_pow_p0 * x + x_pow_p2 = x_pow_p1 * x + return 1. 
/ x + a * x_pow_p0 + b * x_pow_p1 + c * x_pow_p2 + + +class dist_emb(torch.nn.Module): + def __init__(self, num_radial, cutoff=5.0, envelope_exponent=5): + super(dist_emb, self).__init__() + self.cutoff = cutoff + self.envelope = Envelope(envelope_exponent) + + self.freq = torch.nn.Parameter(torch.Tensor(num_radial)) + + self.reset_parameters() + + def reset_parameters(self): + torch.arange(1, self.freq.numel() + 1, out=self.freq).mul_(PI) + + def forward(self, dist): + dist = dist.unsqueeze(-1) / self.cutoff + return self.envelope(dist) * (self.freq * dist).sin() + + +class angle_emb_mp(torch.nn.Module): + def __init__(self, num_spherical=3, num_radial=30, cutoff=8.0, + envelope_exponent=5): + super(angle_emb_mp, self).__init__() + assert num_radial <= 64 + self.num_spherical = num_spherical + self.num_radial = num_radial + self.cutoff = cutoff + # self.envelope = Envelope(envelope_exponent) + + bessel_forms = bessel_basis(num_spherical, num_radial) + sph_harm_forms = real_sph_harm(num_spherical) + self.sph_funcs = [] + self.bessel_funcs = [] + + x, theta = sym.symbols('x theta') + modules = {'sin': torch.sin, 'cos': torch.cos} + for i in range(num_spherical): + if i == 0: + sph1 = sym.lambdify([theta], sph_harm_forms[i][0], modules)(0) + self.sph_funcs.append(lambda x: torch.zeros_like(x) + sph1) + else: + sph = sym.lambdify([theta], sph_harm_forms[i][0], modules) + self.sph_funcs.append(sph) + for j in range(num_radial): + bessel = sym.lambdify([x], bessel_forms[i][j], modules) + self.bessel_funcs.append(bessel) + + def forward(self, dist, angle, idx_kj): + dist = dist / self.cutoff + rbf = torch.stack([f(dist) for f in self.bessel_funcs], dim=1) + # rbf = self.envelope(dist).unsqueeze(-1) * rbf + + cbf = torch.stack([f(angle) for f in self.sph_funcs], dim=1) + + n, k = self.num_spherical, self.num_radial + out = (rbf[idx_kj].view(-1, n, k) * cbf.view(-1, n, 1)).view(-1, n * k) + return out + + +class torsion_emb(torch.nn.Module): + def __init__(self, num_spherical, num_radial, cutoff=5.0, + envelope_exponent=5): + super(torsion_emb, self).__init__() + assert num_radial <= 64 + self.num_spherical = num_spherical # + self.num_radial = num_radial + self.cutoff = cutoff + # self.envelope = Envelope(envelope_exponent) + + bessel_forms = bessel_basis(num_spherical, num_radial) + sph_harm_forms = real_sph_harm(num_spherical, zero_m_only=False) + self.sph_funcs = [] + self.bessel_funcs = [] + + x = sym.symbols('x') + theta = sym.symbols('theta') + phi = sym.symbols('phi') + modules = {'sin': torch.sin, 'cos': torch.cos} + for i in range(self.num_spherical): + if i == 0: + sph1 = sym.lambdify([theta, phi], sph_harm_forms[i][0], modules) + self.sph_funcs.append(lambda x, y: torch.zeros_like(x) + torch.zeros_like(y) + sph1(0,0)) #torch.zeros_like(x) + torch.zeros_like(y) + else: + for k in range(-i, i + 1): + sph = sym.lambdify([theta, phi], sph_harm_forms[i][k+i], modules) + self.sph_funcs.append(sph) + for j in range(self.num_radial): + bessel = sym.lambdify([x], bessel_forms[i][j], modules) + self.bessel_funcs.append(bessel) + + def forward(self, dist, angle, phi, idx_kj): + dist = dist / self.cutoff + rbf = torch.stack([f(dist) for f in self.bessel_funcs], dim=1) + cbf = torch.stack([f(angle, phi) for f in self.sph_funcs], dim=1) + + n, k = self.num_spherical, self.num_radial + out = (rbf[idx_kj].view(-1, 1, n, k) * cbf.view(-1, n, n, 1)).view(-1, n * n * k) + return out + diff --git a/benchmarks/matbench_v0.1_iComFormer/graphs.py b/benchmarks/matbench_v0.1_iComFormer/graphs.py new file mode 
100644 index 00000000..2032befd --- /dev/null +++ b/benchmarks/matbench_v0.1_iComFormer/graphs.py @@ -0,0 +1,575 @@ +"""Module to generate networkx graphs.""" +"""Implementation based on the template of ALIGNN.""" +from multiprocessing.context import ForkContext +from re import X +import numpy as np +import pandas as pd +from jarvis.core.specie import chem_data, get_node_attributes + +# from jarvis.core.atoms import Atoms +from collections import defaultdict +from typing import List, Tuple, Sequence, Optional +import torch +from torch_geometric.data import Data +from torch_geometric.transforms import LineGraph +from torch_geometric.data.batch import Batch +import itertools + +try: + import torch + from tqdm import tqdm +except Exception as exp: + print("torch/tqdm is not installed.", exp) + pass + + +def angle_from_array(a, b, lattice): + a_new = np.dot(a, lattice) + b_new = np.dot(b, lattice) + assert a_new.shape == a.shape + value = sum(a_new * b_new) + length = (sum(a_new ** 2) ** 0.5) * (sum(b_new ** 2) ** 0.5) + cos = value / length + angle = np.arccos(cos) + return angle / np.pi * 180.0 + +def correct_coord_sys(a, b, c, lattice): + a_new = np.dot(a, lattice) + b_new = np.dot(b, lattice) + c_new = np.dot(c, lattice) + assert a_new.shape == a.shape + plane_vec = np.cross(a_new, b_new) + value = sum(plane_vec * c_new) + length = (sum(plane_vec ** 2) ** 0.5) * (sum(c_new ** 2) ** 0.5) + cos = value / length + angle = np.arccos(cos) + return (angle / np.pi * 180.0 <= 90.0) + +def same_line(a, b): + a_new = a / (sum(a ** 2) ** 0.5) + b_new = b / (sum(b ** 2) ** 0.5) + flag = False + if abs(sum(a_new * b_new) - 1.0) < 1e-5: + flag = True + elif abs(sum(a_new * b_new) + 1.0) < 1e-5: + flag = True + else: + flag = False + return flag + +def same_plane(a, b, c): + flag = False + if abs(np.dot(np.cross(a, b), c)) < 1e-5: + flag = True + return flag + +# pyg dataset +class PygStructureDataset(torch.utils.data.Dataset): + """Dataset of crystal DGLGraphs.""" + + def __init__( + self, + df: pd.DataFrame, + graphs: Sequence[Data], + target: str, + atom_features="atomic_number", + transform=None, + line_graph=False, + classification=False, + id_tag="jid", + neighbor_strategy="", + nolinegraph=False, + mean_train=None, + std_train=None, + ): + """Pytorch Dataset for atomistic graphs. + + `df`: pandas dataframe from e.g. 
jarvis.db.figshare.data + `graphs`: DGLGraph representations corresponding to rows in `df` + `target`: key for label column in `df` + """ + self.df = df + self.graphs = graphs + self.target = target + self.line_graph = line_graph + + # self.ids = self.df[id_tag] + self.atoms = self.df['atoms'] + self.labels = torch.tensor(self.df[target]).type( + torch.get_default_dtype() + ) + print("mean %f std %f"%(self.labels.mean(), self.labels.std())) + if mean_train == None: + mean = self.labels.mean() + std = self.labels.std() + self.labels = (self.labels - mean) / std + print("normalize using training mean but shall not be used here %f and std %f" % (mean, std)) + else: + self.labels = (self.labels - mean_train) / std_train + print("normalize using training mean %f and std %f" % (mean_train, std_train)) + + self.transform = transform + + features = self._get_attribute_lookup(atom_features) + + # load selected node representation + # assume graphs contain atomic number in g.ndata["atom_features"] + for g in graphs: + z = g.x + g.atomic_number = z + z = z.type(torch.IntTensor).squeeze() + f = torch.tensor(features[z]).type(torch.FloatTensor) + if g.x.size(0) == 1: + f = f.unsqueeze(0) + g.x = f + + self.prepare_batch = prepare_pyg_batch + if line_graph: + self.prepare_batch = prepare_pyg_line_graph_batch + print("building line graphs") + # if not nolinegraph: + # self.line_graphs = [] + # self.graphs = [] + # for g in tqdm(graphs): + # linegraph_trans = LineGraph(force_directed=True) + # g_new = Data() + # g_new.x, g_new.edge_index, g_new.edge_attr, g_new.edge_type = g.x, g.edge_index, g.edge_attr, g.edge_type + # try: + # lg = linegraph_trans(g) + # except Exception as exp: + # print(g.x, g.edge_attr, exp) + # pass + # lg.edge_attr = pyg_compute_bond_cosines(lg) # old cosine emb + # # lg.edge_attr = pyg_compute_bond_angle(lg) + # self.graphs.append(g_new) + # self.line_graphs.append(lg) + # else: + # + self.graphs = [] + for g in tqdm(graphs): + g.edge_attr = g.edge_attr.float() + self.graphs.append(g) + self.line_graphs = self.graphs + + + if classification: + self.labels = self.labels.view(-1).long() + print("Classification dataset.", self.labels) + + @staticmethod + def _get_attribute_lookup(atom_features: str = "cgcnn"): + """Build a lookup array indexed by atomic number.""" + max_z = max(v["Z"] for v in chem_data.values()) + + # get feature shape (referencing Carbon) + template = get_node_attributes("C", atom_features) + + features = np.zeros((1 + max_z, len(template))) + + for element, v in chem_data.items(): + z = v["Z"] + x = get_node_attributes(element, atom_features) + + if x is not None: + features[z, :] = x + + return features + + def __len__(self): + """Get length.""" + return self.labels.shape[0] + + def __getitem__(self, idx): + """Get StructureDataset sample.""" + g = self.graphs[idx] + label = self.labels[idx] + + if self.line_graph: + return g, self.line_graphs[idx], label, label + + return g, label + + def setup_standardizer(self, ids): + """Atom-wise feature standardization transform.""" + x = torch.cat( + [ + g.x + for idx, g in enumerate(self.graphs) + if idx in ids + ] + ) + self.atom_feature_mean = x.mean(0) + self.atom_feature_std = x.std(0) + + self.transform = PygStandardize( + self.atom_feature_mean, self.atom_feature_std + ) + + @staticmethod + def collate(samples: List[Tuple[Data, torch.Tensor]]): + """Dataloader helper to batch graphs cross `samples`.""" + graphs, labels = map(list, zip(*samples)) + batched_graph = Batch.from_data_list(graphs) + return batched_graph, 
torch.tensor(labels) + + @staticmethod + def collate_line_graph( + samples: List[Tuple[Data, Data, torch.Tensor, torch.Tensor]] + ): + """Dataloader helper to batch graphs cross `samples`.""" + graphs, line_graphs, lattice, labels = map(list, zip(*samples)) + batched_graph = Batch.from_data_list(graphs) + batched_line_graph = Batch.from_data_list(line_graphs) + if len(labels[0].size()) > 0: + return batched_graph, batched_line_graph, torch.cat([i.unsqueeze(0) for i in lattice]), torch.stack(labels) + else: + return batched_graph, batched_line_graph, torch.cat([i.unsqueeze(0) for i in lattice]), torch.tensor(labels) + +def canonize_edge( + src_id, + dst_id, + src_image, + dst_image, +): + """Compute canonical edge representation. + + Sort vertex ids + shift periodic images so the first vertex is in (0,0,0) image + """ + # store directed edges src_id <= dst_id + if dst_id < src_id: + src_id, dst_id = dst_id, src_id + src_image, dst_image = dst_image, src_image + + # shift periodic images so that src is in (0,0,0) image + if not np.array_equal(src_image, (0, 0, 0)): + shift = src_image + src_image = tuple(np.subtract(src_image, shift)) + dst_image = tuple(np.subtract(dst_image, shift)) + + assert src_image == (0, 0, 0) + + return src_id, dst_id, src_image, dst_image + + +def nearest_neighbor_edges_submit( + atoms=None, + cutoff=8, + max_neighbors=12, + id=None, + use_canonize=False, + use_lattice=False, + use_angle=False, +): + """Construct k-NN edge list.""" + # returns List[List[Tuple[site, distance, index, image]]] + lat = atoms.lattice + all_neighbors_now = atoms.get_all_neighbors(r=cutoff) + min_nbrs = min(len(neighborlist) for neighborlist in all_neighbors_now) + + attempt = 0 + if min_nbrs < max_neighbors: + lat = atoms.lattice + if cutoff < max(lat.a, lat.b, lat.c): + r_cut = max(lat.a, lat.b, lat.c) + else: + r_cut = 2 * cutoff + attempt += 1 + return nearest_neighbor_edges_submit( + atoms=atoms, + use_canonize=use_canonize, + cutoff=r_cut, + max_neighbors=max_neighbors, + id=id, + use_lattice=use_lattice, + ) + + edges = defaultdict(set) + # lattice correction process + r_cut = max(lat.a, lat.b, lat.c) + 1e-2 + all_neighbors = atoms.get_all_neighbors(r=r_cut) + neighborlist = all_neighbors[0] + neighborlist = sorted(neighborlist, key=lambda x: x[2]) + ids = np.array([nbr[1] for nbr in neighborlist]) + images = np.array([nbr[3] for nbr in neighborlist]) + images = images[ids == 0] + lat1 = images[0] + # finding lat2 + start = 1 + for i in range(start, len(images)): + lat2 = images[i] + if not same_line(lat1, lat2): + start = i + break + # finding lat3 + for i in range(start, len(images)): + lat3 = images[i] + if not same_plane(lat1, lat2, lat3): + break + # find the invariant corner + if angle_from_array(lat1,lat2,lat.matrix) > 90.0: + lat2 = - lat2 + if angle_from_array(lat1,lat3,lat.matrix) > 90.0: + lat3 = - lat3 + # find the invariant coord system + if not correct_coord_sys(lat1, lat2, lat3, lat.matrix): + lat1 = - lat1 + lat2 = - lat2 + lat3 = - lat3 + + # if not correct_coord_sys(lat1, lat2, lat3, lat.matrix): + # print(lat1, lat2, lat3) + # lattice correction end + for site_idx, neighborlist in enumerate(all_neighbors_now): + + # sort on distance + neighborlist = sorted(neighborlist, key=lambda x: x[2]) + distances = np.array([nbr[2] for nbr in neighborlist]) + ids = np.array([nbr[1] for nbr in neighborlist]) + images = np.array([nbr[3] for nbr in neighborlist]) + + # find the distance to the k-th nearest neighbor + max_dist = distances[max_neighbors - 1] + ids = ids[distances 
<= max_dist] + images = images[distances <= max_dist] + distances = distances[distances <= max_dist] + for dst, image in zip(ids, images): + src_id, dst_id, src_image, dst_image = canonize_edge( + site_idx, dst, (0, 0, 0), tuple(image) + ) + if use_canonize: + edges[(src_id, dst_id)].add(dst_image) + else: + edges[(site_idx, dst)].add(tuple(image)) + + if use_lattice: + edges[(site_idx, site_idx)].add(tuple(lat1)) + edges[(site_idx, site_idx)].add(tuple(lat2)) + edges[(site_idx, site_idx)].add(tuple(lat3)) + + return edges, lat1, lat2, lat3 + + +def compute_bond_cosine(v1, v2): + """Compute bond angle cosines from bond displacement vectors.""" + v1 = torch.tensor(v1).type(torch.get_default_dtype()) + v2 = torch.tensor(v2).type(torch.get_default_dtype()) + bond_cosine = torch.sum(v1 * v2) / ( + torch.norm(v1) * torch.norm(v2) + ) + bond_cosine = torch.clamp(bond_cosine, -1, 1) + return bond_cosine + + +def build_undirected_edgedata( + atoms=None, + edges={}, + a=None, + b=None, + c=None, +): + """Build undirected graph data from edge set. + + edges: dictionary mapping (src_id, dst_id) to set of dst_image + r: cartesian displacement vector from src -> dst + """ + # second pass: construct *undirected* graph + # import pprint + u, v, r, l, nei, angle, atom_lat = [], [], [], [], [], [], [] + v1, v2, v3 = atoms.lattice.cart_coords(a), atoms.lattice.cart_coords(b), atoms.lattice.cart_coords(c) + atom_lat.append([v1, v2, v3]) + for (src_id, dst_id), images in edges.items(): + + for dst_image in images: + # fractional coordinate for periodic image of dst + dst_coord = atoms.frac_coords[dst_id] + dst_image + # cartesian displacement vector pointing from src -> dst + d = atoms.lattice.cart_coords( + dst_coord - atoms.frac_coords[src_id] + ) + for uu, vv, dd in [(src_id, dst_id, d), (dst_id, src_id, -d)]: + u.append(uu) + v.append(vv) + r.append(dd) + nei.append([v1, v2, v3]) + # angle.append([compute_bond_cosine(dd, v1), compute_bond_cosine(dd, v2), compute_bond_cosine(dd, v3)]) + # if np.linalg.norm(d)!=0: + # print ('jv',dst_image,d) + + u = torch.tensor(u) + v = torch.tensor(v) + r = torch.tensor(r).type(torch.get_default_dtype()) + l = torch.tensor(l).type(torch.int) + nei = torch.tensor(nei).type(torch.get_default_dtype()) + atom_lat = torch.tensor(atom_lat).type(torch.get_default_dtype()) + # nei_angles = torch.tensor(angle).type(torch.get_default_dtype()) + return u, v, r, l, nei, atom_lat + + +class PygGraph(object): + """Generate a graph object.""" + + def __init__( + self, + nodes=[], + node_attributes=[], + edges=[], + edge_attributes=[], + color_map=None, + labels=None, + ): + """ + Initialize the graph object. + + Args: + nodes: IDs of the graph nodes as integer array. + + node_attributes: node features as multi-dimensional array. + + edges: connectivity as a (u,v) pair where u is + the source index and v the destination ID. + + edge_attributes: attributes for each connectivity. + as simple as euclidean distances. 
+ """ + self.nodes = nodes + self.node_attributes = node_attributes + self.edges = edges + self.edge_attributes = edge_attributes + self.color_map = color_map + self.labels = labels + + @staticmethod + def atom_dgl_multigraph( + atoms=None, + neighbor_strategy="k-nearest", + cutoff=4.0, + max_neighbors=12, + atom_features="cgcnn", + max_attempts=3, + id: Optional[str] = None, + compute_line_graph: bool = True, + use_canonize: bool = False, + use_lattice: bool = False, + use_angle: bool = False, + ): + if neighbor_strategy == "k-nearest": + edges, a, b, c = nearest_neighbor_edges_submit( + atoms=atoms, + cutoff=cutoff, + max_neighbors=max_neighbors, + id=id, + use_canonize=use_canonize, + use_lattice=use_lattice, + use_angle=use_angle, + ) + u, v, r, l, nei, atom_lat = build_undirected_edgedata(atoms, edges, a, b, c) + elif neighbor_strategy == "pairwise-k-nearest": + u, v, r = pair_nearest_neighbor_edges( + atoms=atoms, + pair_wise_distances=2, + use_lattice=use_lattice, + use_angle=use_angle, + ) + else: + raise ValueError("Not implemented yet", neighbor_strategy) + + + # build up atom attribute tensor + sps_features = [] + for ii, s in enumerate(atoms.elements): + feat = list(get_node_attributes(s, atom_features=atom_features)) + sps_features.append(feat) + sps_features = np.array(sps_features) + node_features = torch.tensor(sps_features).type( + torch.get_default_dtype() + ) + atom_lat = atom_lat.repeat(node_features.shape[0],1,1) + edge_index = torch.cat((u.unsqueeze(0), v.unsqueeze(0)), dim=0).long() + g = Data(x=node_features, edge_index=edge_index, edge_attr=r, edge_type=l, edge_nei=nei, atom_lat=atom_lat) + + return g + +def pyg_compute_bond_cosines(lg): + """Compute bond angle cosines from bond displacement vectors.""" + # line graph edge: (a, b), (b, c) + # `a -> b -> c` + # use law of cosines to compute angles cosines + # negate src bond so displacements are like `a <- b -> c` + # cos(theta) = ba \dot bc / (||ba|| ||bc||) + src, dst = lg.edge_index + x = lg.x + r1 = -x[src] + r2 = x[dst] + bond_cosine = torch.sum(r1 * r2, dim=1) / ( + torch.norm(r1, dim=1) * torch.norm(r2, dim=1) + ) + bond_cosine = torch.clamp(bond_cosine, -1, 1) + return bond_cosine + +def pyg_compute_bond_angle(lg): + """Compute bond angle from bond displacement vectors.""" + # line graph edge: (a, b), (b, c) + # `a -> b -> c` + src, dst = lg.edge_index + x = lg.x + r1 = -x[src] + r2 = x[dst] + a = (r1 * r2).sum(dim=-1) # cos_angle * |pos_ji| * |pos_jk| + b = torch.cross(r1, r2).norm(dim=-1) # sin_angle * |pos_ji| * |pos_jk| + angle = torch.atan2(b, a) + return angle + + + +class PygStandardize(torch.nn.Module): + """Standardize atom_features: subtract mean and divide by std.""" + + def __init__(self, mean: torch.Tensor, std: torch.Tensor): + """Register featurewise mean and standard deviation.""" + super().__init__() + self.mean = mean + self.std = std + + def forward(self, g: Data): + """Apply standardization to atom_features.""" + h = g.x + g.x = (h - self.mean) / self.std + return g + + + +def prepare_pyg_batch( + batch: Tuple[Data, torch.Tensor], device=None, non_blocking=False +): + """Send batched dgl crystal graph to device.""" + g, t = batch + batch = ( + g.to(device), + t.to(device, non_blocking=non_blocking), + ) + + return batch + + +def prepare_pyg_line_graph_batch( + batch: Tuple[Tuple[Data, Data, torch.Tensor], torch.Tensor], + device=None, + non_blocking=False, +): + """Send line graph batch to device. 
+ + Note: the batch is a nested tuple, with the graph and line graph together + """ + g, lg, lattice, t = batch + batch = ( + ( + g.to(device), + lg.to(device), + lattice.to(device, non_blocking=non_blocking), + ), + t.to(device, non_blocking=non_blocking), + ) + + return batch + diff --git a/benchmarks/matbench_v0.1_iComFormer/info.json b/benchmarks/matbench_v0.1_iComFormer/info.json new file mode 100644 index 00000000..e060a08f --- /dev/null +++ b/benchmarks/matbench_v0.1_iComFormer/info.json @@ -0,0 +1,14 @@ +{ + "authors": "Keqiang Yan, Cong Fu, Xiaofeng Qian, Xiaoning Qian, Shuiwang Ji", + "algorithm": "iComFormer", + "algorithm_long": "Complete and efficient graph transformer for materials property prediction", + "bibtex_refs": "@inproceedings{ \n yan2024complete, \n title={Complete and Efficient Graph Transformers for Crystal Material Property Prediction},\n author={Keqiang Yan and Cong Fu and Xiaofeng Qian and Xiaoning Qian and Shuiwang Ji},\n booktitle={The Twelfth International Conference on Learning Representations},\n year={2024},\n url={https://openreview.net/forum?id=BnQY9XiRAS}\n}", + "notes": "This is the invariant version of ComFormer", + "requirements": { + "python": [ + "pytorch==1.13.1", + "torch_geometric==2.3.0", + "matbench==0.1.0", "pymatgen==2023.3.23" + ] + } +} \ No newline at end of file diff --git a/benchmarks/matbench_v0.1_iComFormer/models/__init__.py b/benchmarks/matbench_v0.1_iComFormer/models/__init__.py new file mode 100644 index 00000000..ccc4f536 --- /dev/null +++ b/benchmarks/matbench_v0.1_iComFormer/models/__init__.py @@ -0,0 +1 @@ +"""Graph neural network implementations.""" diff --git a/benchmarks/matbench_v0.1_iComFormer/models/backup.py b/benchmarks/matbench_v0.1_iComFormer/models/backup.py new file mode 100644 index 00000000..b027ca27 --- /dev/null +++ b/benchmarks/matbench_v0.1_iComFormer/models/backup.py @@ -0,0 +1,801 @@ +class MPNNConv(MessagePassing): + """Implements the message passing layer from + `"Crystal Graph Convolutional Neural Networks for an + Accurate and Interpretable Prediction of Material Properties" + `.
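+ The gated message for an edge (i, j) is sigmoid(BN(MLP([x_i, x_j, e_ij]))) * MLP'([x_i, x_j, e_ij]);
+ `forward` then adds the aggregated messages back to the node features through a batch-normalized residual connection followed by ReLU.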
+ """ + + def init(self, fc_features): + super(MPNNConv, self).init(node_dim=0) + self.bn = nn.BatchNorm1d(fc_features) + self.bn_interaction = nn.BatchNorm1d(fc_features) + self.nonlinear_full = nn.Sequential( + nn.Linear(3 * fc_features, fc_features), + nn.SiLU(), + nn.Linear(fc_features, fc_features) + ) + self.nonlinear = nn.Sequential( + nn.Linear(3 * fc_features, fc_features), + nn.SiLU(), + nn.Linear(fc_features, fc_features), + ) + + def forward(self, x, edge_index, edge_attr): + """ + Arguments: + x has shape [num_nodes, node_feat_size] + edge_index has shape [2, num_edges] + edge_attr is [num_edges, edge_feat_size] + """ + + out = self.propagate( + edge_index, x=x, edge_attr=edge_attr, size=(x.size(0), x.size(0)) + ) + + return F.relu(x + self.bn(out)) + + def message(self, x_i, x_j, edge_attr, index): + score = torch.sigmoid(self.bn_interaction(self.nonlinear_full(torch.cat((x_i, x_j, edge_attr), dim=1)))) + return score * self.nonlinear(torch.cat((x_i, x_j, edge_attr), dim=1)) + + + + +############ +# 03/08/2023 +class MatformerConv(MessagePassing): + _alpha: OptTensor + + def __init__( + self, + in_channels: Union[int, Tuple[int, int]], + out_channels: int, + heads: int = 1, + concat: bool = True, + beta: bool = False, + dropout: float = 0.0, + edge_dim: Optional[int] = None, + bias: bool = True, + root_weight: bool = True, + **kwargs, + ): + kwargs.setdefault('aggr', 'add') + super(MatformerConv, self).__init__(node_dim=0, **kwargs) + + self.in_channels = in_channels + self.out_channels = out_channels + self.heads = heads + self.beta = beta and root_weight + self.root_weight = root_weight + self.concat = concat + self.dropout = dropout + self.edge_dim = edge_dim + self._alpha = None + + if isinstance(in_channels, int): + in_channels = (in_channels, in_channels) + + self.lin_key = nn.Linear(in_channels[0], heads * out_channels) + self.lin_query = nn.Linear(in_channels[1], heads * out_channels) + self.lin_value = nn.Linear(in_channels[0], heads * out_channels) + + if edge_dim is not None: + self.lin_edge = nn.Linear(edge_dim, heads * out_channels, bias=False) + else: + self.lin_edge = self.register_parameter('lin_edge', None) + + if concat: + self.lin_skip = nn.Linear(in_channels[1], out_channels, + bias=bias) + self.lin_concate = nn.Linear(heads * out_channels, out_channels) + if self.beta: + self.lin_beta = nn.Linear(3 * heads * out_channels, 1, bias=False) + else: + self.lin_beta = self.register_parameter('lin_beta', None) + else: + self.lin_skip = nn.Linear(in_channels[1], out_channels, bias=bias) + if self.beta: + self.lin_beta = nn.Linear(3 * out_channels, 1, bias=False) + else: + self.lin_beta = self.register_parameter('lin_beta', None) + self.lin_msg_update = nn.Linear(out_channels * 3, out_channels * 3) + self.layer_norm = nn.LayerNorm(out_channels * 3) + self.msg_layer = nn.Sequential(nn.Linear(out_channels * 3, out_channels), nn.LayerNorm(out_channels)) + # simpler version + # self.lin_msg_update = nn.Linear(out_channels * 3, out_channels) + # self.layer_norm = nn.LayerNorm(out_channels) + # self.msg_layer = nn.Sequential(nn.Linear(out_channels, out_channels), nn.LayerNorm(out_channels)) + # self.msg_layer = nn.Linear(out_channels * 3, out_channels) + self.bn = nn.BatchNorm1d(out_channels) + # self.bn = nn.BatchNorm1d(out_channels * heads) + self.sigmoid = nn.Sigmoid() + self.reset_parameters() + + def reset_parameters(self): + self.lin_key.reset_parameters() + self.lin_query.reset_parameters() + self.lin_value.reset_parameters() + if self.concat: + 
self.lin_concate.reset_parameters() + if self.edge_dim: + self.lin_edge.reset_parameters() + self.lin_skip.reset_parameters() + if self.beta: + self.lin_beta.reset_parameters() + + def forward(self, x: Union[Tensor, PairTensor], edge_index: Adj, + edge_attr: OptTensor = None, return_attention_weights=None): + + H, C = self.heads, self.out_channels + if isinstance(x, Tensor): + x: PairTensor = (x, x) + + query = self.lin_query(x[1]).view(-1, H, C) + key = self.lin_key(x[0]).view(-1, H, C) + value = self.lin_value(x[0]).view(-1, H, C) + + out = self.propagate(edge_index, query=query, key=key, value=value, + edge_attr=edge_attr, size=None) + alpha = self._alpha + self._alpha = None + + if self.concat: + out = out.view(-1, self.heads * self.out_channels) + else: + out = out.mean(dim=1) + + if self.concat: + out = self.lin_concate(out) + + out = F.silu(self.bn(out)) # after norm and silu + + if self.root_weight: + x_r = self.lin_skip(x[1]) + if self.lin_beta is not None: + beta = self.lin_beta(torch.cat([out, x_r, out - x_r], dim=-1)) + beta = beta.sigmoid() + out = beta * x_r + (1 - beta) * out + else: + out += x_r + + + if isinstance(return_attention_weights, bool): + assert alpha is not None + if isinstance(edge_index, Tensor): + return out, (edge_index, alpha) + elif isinstance(edge_index, SparseTensor): + return out, edge_index.set_value(alpha, layout='coo') + else: + return out + + def message(self, query_i: Tensor, key_i: Tensor, key_j: Tensor, value_j: Tensor, value_i: Tensor, + edge_attr: OptTensor, index: Tensor, ptr: OptTensor, + size_i: Optional[int]) -> Tensor: + + if self.lin_edge is not None: + assert edge_attr is not None + edge_attr = self.lin_edge(edge_attr).view(-1, self.heads,self.out_channels) + + query_i = torch.cat((query_i, query_i, query_i), dim=-1) + key_j = torch.cat((key_i, key_j, edge_attr), dim=-1) + alpha = (query_i * key_j) / math.sqrt(self.out_channels * 3) + self._alpha = alpha + alpha = F.dropout(alpha, p=self.dropout, training=self.training) + out = torch.cat((value_i, value_j, edge_attr), dim=-1) + out = self.lin_msg_update(out) * self.sigmoid(self.layer_norm(alpha.view(-1, self.heads, 3 * self.out_channels))) + out = self.msg_layer(out) + + # version two, simpler + # query_i = query_i + # key_j = key_j + # alpha = (query_i * key_j) / math.sqrt(self.out_channels) + # self._alpha = alpha + # out = torch.cat((value_i, value_j, edge_attr), dim=-1) + # out = self.lin_msg_update(out) * self.sigmoid(self.layer_norm(alpha.view(-1, self.heads, self.out_channels))) + # out = self.msg_layer(out) + return out + + def __repr__(self) -> str: + return (f'{self.__class__.__name__}({self.in_channels}, ' + f'{self.out_channels}, heads={self.heads})') + + +class MatformerConv_edge(nn.Module): + def __init__( + self, + in_channels: Union[int, Tuple[int, int]], + out_channels: int, + heads: int = 1, + concat: bool = True, + beta: bool = False, + dropout: float = 0.0, + edge_dim: Optional[int] = None, + bias: bool = True, + root_weight: bool = True, + ): + super().__init__() + + self.in_channels = in_channels + self.out_channels = out_channels + self.heads = heads + self.beta = beta and root_weight + self.root_weight = root_weight + self.concat = concat + self.dropout = dropout + self.edge_dim = edge_dim + + if isinstance(in_channels, int): + in_channels = (in_channels, in_channels) + + self.lin_key = nn.Linear(in_channels[0], heads * out_channels) + self.lin_query = nn.Linear(in_channels[1], heads * out_channels) + self.lin_value = nn.Linear(in_channels[0], heads * 
out_channels) + + if edge_dim is not None: + self.lin_edge = nn.Linear(edge_dim, heads * out_channels, bias=False) + else: + self.lin_edge = self.register_parameter('lin_edge', None) + + if concat: + self.lin_skip = nn.Linear(in_channels[1], out_channels, + bias=bias) + self.lin_concate = nn.Linear(heads * out_channels, out_channels) + if self.beta: + self.lin_beta = nn.Linear(3 * heads * out_channels, 1, bias=False) + else: + self.lin_beta = self.register_parameter('lin_beta', None) + else: + self.lin_skip = nn.Linear(in_channels[1], out_channels, bias=bias) + if self.beta: + self.lin_beta = nn.Linear(3 * out_channels, 1, bias=False) + else: + self.lin_beta = self.register_parameter('lin_beta', None) + self.lin_msg_update = nn.Linear(out_channels * 3, out_channels * 3) + self.layer_norm = nn.LayerNorm(out_channels * 3) + self.msg_layer = nn.Sequential(nn.Linear(out_channels * 3, out_channels), nn.LayerNorm(out_channels)) + self.bn = nn.BatchNorm1d(out_channels) + self.sigmoid = nn.Sigmoid() + self.reset_parameters() + + def reset_parameters(self): + self.lin_key.reset_parameters() + self.lin_query.reset_parameters() + self.lin_value.reset_parameters() + if self.concat: + self.lin_concate.reset_parameters() + if self.edge_dim: + self.lin_edge.reset_parameters() + self.lin_skip.reset_parameters() + if self.beta: + self.lin_beta.reset_parameters() + + def forward(self, edge: Union[Tensor, PairTensor], edge_nei_len: OptTensor = None, edge_nei_angle: OptTensor = None): + # preprocess for edge of shape [num_edges, hidden_dim] + + H, C = self.heads, self.out_channels + if isinstance(edge, Tensor): + edge: PairTensor = (edge, edge) + + query_x = self.lin_query(edge[1]).view(-1, H, C).unsqueeze(1).repeat(1, 3, 1, 1) + key_x = self.lin_key(edge[0]).view(-1, H, C).unsqueeze(1).repeat(1, 3, 1, 1) + value_x = self.lin_value(edge[0]).view(-1, H, C).unsqueeze(1).repeat(1, 3, 1, 1) + + # preprocess for edge_neighbor of shape [num_edges, 3, hidden_dim] + query_y = self.lin_query(edge_nei_len).view(-1, 3, H, C) + key_y = self.lin_key(edge_nei_len).view(-1, 3, H, C) + value_y = self.lin_value(edge_nei_len).view(-1, 3, H, C) + + # preprocess for interaction of shape [num_edges, 3, hidden_dim] + edge_xy = self.lin_edge(edge_nei_angle).view(-1, 3, H, C) + + query = torch.cat((query_x, query_x, query_x), dim=-1) + key = torch.cat((key_x, key_y, edge_xy), dim=-1) + alpha = (query * key) / math.sqrt(self.out_channels * 3) + out = torch.cat((value_x, value_y, edge_xy), dim=-1) + out = self.lin_msg_update(out) * self.sigmoid(self.layer_norm(alpha)) + out = self.msg_layer(out) + + if self.concat: + out = out.view(-1, 3, self.heads * self.out_channels) + else: + out = out.mean(dim=2) + + if self.concat: + out = self.lin_concate(out) + + # aggregate the msg + out = out.sum(dim=1) + + out = F.silu(self.bn(out)) + + if self.root_weight: + x_r = self.lin_skip(edge[1]) + out += x_r + + return out + + +##################### +# 03/07/2023 +##################### + + +# class MatformerConv_edge(MessagePassing): +# _alpha: OptTensor + +# def __init__( +# self, +# in_channels: Union[int, Tuple[int, int]], +# out_channels: int, +# heads: int = 1, +# concat: bool = True, +# beta: bool = False, +# dropout: float = 0.0, +# edge_dim: Optional[int] = None, +# bias: bool = True, +# root_weight: bool = True, +# **kwargs, +# ): +# kwargs.setdefault('aggr', 'add') +# super(MatformerConv_edge, self).__init__(node_dim=0, **kwargs) + +# self.in_channels = in_channels +# self.out_channels = out_channels +# self.heads = heads +# self.beta = 
beta and root_weight +# self.root_weight = root_weight +# self.concat = concat +# self.dropout = dropout +# self.edge_dim = edge_dim +# self._alpha = None + +# if isinstance(in_channels, int): +# in_channels = (in_channels, in_channels) + +# self.lin_key = nn.Linear(in_channels[0], heads * out_channels) +# self.lin_query = nn.Linear(in_channels[1], heads * out_channels) +# self.lin_value = nn.Linear(in_channels[0], heads * out_channels) + +# if edge_dim is not None: +# self.lin_edge = nn.Linear(edge_dim, heads * out_channels, bias=False) +# else: +# self.lin_edge = self.register_parameter('lin_edge', None) + +# if concat: +# self.lin_skip = nn.Linear(in_channels[1], out_channels, +# bias=bias) +# self.lin_concate = nn.Linear(heads * out_channels, out_channels) +# if self.beta: +# self.lin_beta = nn.Linear(3 * heads * out_channels, 1, bias=False) +# else: +# self.lin_beta = self.register_parameter('lin_beta', None) +# else: +# self.lin_skip = nn.Linear(in_channels[1], out_channels, bias=bias) +# if self.beta: +# self.lin_beta = nn.Linear(3 * out_channels, 1, bias=False) +# else: +# self.lin_beta = self.register_parameter('lin_beta', None) +# self.lin_msg_update = nn.Linear(out_channels * 3, out_channels * 3) +# self.layer_norm = nn.LayerNorm(out_channels * 3) +# self.msg_layer = nn.Sequential(nn.Linear(out_channels * 3, out_channels), nn.LayerNorm(out_channels)) +# # simpler version +# # self.lin_msg_update = nn.Linear(out_channels * 3, out_channels) +# # self.layer_norm = nn.LayerNorm(out_channels) +# # self.msg_layer = nn.Sequential(nn.Linear(out_channels, out_channels), nn.LayerNorm(out_channels)) +# # self.msg_layer = nn.Linear(out_channels * 3, out_channels) +# self.bn = nn.BatchNorm1d(out_channels) +# # self.bn = nn.BatchNorm1d(out_channels * heads) +# self.sigmoid = nn.Sigmoid() +# self.reset_parameters() + +# def reset_parameters(self): +# self.lin_key.reset_parameters() +# self.lin_query.reset_parameters() +# self.lin_value.reset_parameters() +# if self.concat: +# self.lin_concate.reset_parameters() +# if self.edge_dim: +# self.lin_edge.reset_parameters() +# self.lin_skip.reset_parameters() +# if self.beta: +# self.lin_beta.reset_parameters() + +# def forward(self, x: Union[Tensor, PairTensor], edge_index: Adj, +# edge_attr: OptTensor = None, return_attention_weights=None): + +# H, C = self.heads, self.out_channels +# if isinstance(x, Tensor): +# x: PairTensor = (x, x) + +# query = self.lin_query(x[1]).view(-1, H, C) +# key = self.lin_key(x[0]).view(-1, H, C) +# value = self.lin_value(x[0]).view(-1, H, C) + +# out = self.propagate(edge_index, query=query, key=key, value=value, +# edge_attr=edge_attr, size=None) +# alpha = self._alpha +# self._alpha = None + +# if self.concat: +# out = out.view(-1, self.heads * self.out_channels) +# else: +# out = out.mean(dim=1) + +# if self.concat: +# out = self.lin_concate(out) + +# out = F.silu(self.bn(out)) # after norm and silu + +# if self.root_weight: +# x_r = self.lin_skip(x[1]) +# if self.lin_beta is not None: +# beta = self.lin_beta(torch.cat([out, x_r, out - x_r], dim=-1)) +# beta = beta.sigmoid() +# out = beta * x_r + (1 - beta) * out +# else: +# out += x_r + + +# if isinstance(return_attention_weights, bool): +# assert alpha is not None +# if isinstance(edge_index, Tensor): +# return out, (edge_index, alpha) +# elif isinstance(edge_index, SparseTensor): +# return out, edge_index.set_value(alpha, layout='coo') +# else: +# return out + +# def message(self, query_i: Tensor, key_i: Tensor, key_j: Tensor, value_j: Tensor, value_i: Tensor, 
+# edge_attr: OptTensor, index: Tensor, ptr: OptTensor, +# size_i: Optional[int]) -> Tensor: + +# if self.lin_edge is not None: +# assert edge_attr is not None +# edge_attr = self.lin_edge(edge_attr).view(-1, self.heads,self.out_channels) + +# query_i = torch.cat((query_i, query_i, query_i), dim=-1) +# key_j = torch.cat((key_i, key_j, edge_attr), dim=-1) +# alpha = (query_i * key_j) / math.sqrt(self.out_channels * 3) +# self._alpha = alpha +# alpha = F.dropout(alpha, p=self.dropout, training=self.training) +# out = torch.cat((value_i, value_j, edge_attr), dim=-1) +# out = self.lin_msg_update(out) * self.sigmoid(self.layer_norm(alpha.view(-1, self.heads, 3 * self.out_channels))) +# out = self.msg_layer(out) + +# # version two, simpler +# # query_i = query_i +# # key_j = key_j +# # alpha = (query_i * key_j) / math.sqrt(self.out_channels) +# # self._alpha = alpha +# # out = torch.cat((value_i, value_j, edge_attr), dim=-1) +# # out = self.lin_msg_update(out) * self.sigmoid(self.layer_norm(alpha.view(-1, self.heads, self.out_channels))) +# # out = self.msg_layer(out) +# return out + +# def __repr__(self) -> str: +# return (f'{self.__class__.__name__}({self.in_channels}, ' +# f'{self.out_channels}, heads={self.heads})') + + + +##################### +# 03/21/2023 +##################### + +class MatformerConvEqui(nn.Module): + def __init__( + self, + in_channels: Union[int, Tuple[int, int]], + out_channels: int, + edge_dim: Optional[int] = None, + use_second_order_repr: bool = True, + ns: int = 64, + nv: int = 8, + residual: bool = True, + ): + super().__init__() + + if use_second_order_repr: + irrep_seq = [ + f'{ns}x0e', + f'{ns}x0e + {nv}x1o + {nv}x2e', + f'{out_channels}x0e' + ] + else: + irrep_seq = [ + f'{ns}x0e', + f'{ns}x0e + {nv}x1o', + f'{out_channels}x0e', + ] + self.ns = ns + self.nv = nv + self.node_linear = nn.Linear(in_channels, ns) + self.skip_linear = nn.Linear(in_channels, out_channels) + self.v1_v2_linear = nn.Linear(ns, out_channels) + + self.sh = '1x0e + 1x1o + 1x2e' + self.v2_tp = v2_tp = o3.FullyConnectedTensorProduct(f'{ns}x0e + {nv}x1o + {nv}x2e', '1x0e + 1x1o + 1x2e', f'{out_channels}x0e', shared_weights=False) + self.v2_fc = nn.Sequential( + nn.Linear(edge_dim * 3, edge_dim), + nn.Softplus(), + nn.Linear(edge_dim, v2_tp.weight_numel) + ) + self.nlayer_1 = TensorProductConvLayer( + in_irreps=irrep_seq[0], + sh_irreps=self.sh, + out_irreps=irrep_seq[1], + n_edge_features=edge_dim, + residual=residual + ) + + self.nlayer_2 = TensorProductConvLayer( + in_irreps=irrep_seq[1], + sh_irreps=self.sh, + out_irreps=irrep_seq[1], + n_edge_features=edge_dim, + residual=residual + ) + # MACE + self.softplus = nn.Softplus() + self.ln_0e = nn.Parameter(torch.ones(1, 3, 1)) + self.ln_1o = nn.Parameter(torch.ones(1, 3, 1)) + self.ln_2e = nn.Parameter(torch.ones(1, 3, 1)) + self.bn = nn.BatchNorm1d(out_channels) + + def forward(self, data, node_feature: Union[Tensor, PairTensor], edge_index: Adj, edge_feature: Union[Tensor, PairTensor], + lat_len: OptTensor = None): + edge_vec = data.edge_attr + n_ = node_feature.shape[0] + edge_irr = o3.spherical_harmonics(self.sh, edge_vec, normalize=True, normalization='component') + lat = o3.spherical_harmonics(self.sh, data.atom_lat.view(n_ * 3, 3), normalize=True, normalization='component') + skip_connect = node_feature + node_feature = self.node_linear(node_feature) + irreps = o3.Irreps('1x0e + 1x1o + 1x2e') + node_feature = self.nlayer_1(node_feature, edge_index, edge_feature, edge_irr) + # the second layer + lat_l0, lat_l1o, lat_l2e = lat[:, 
irreps.slices()[0]], lat[:, irreps.slices()[1]], lat[:, irreps.slices()[2]] + lat_l0 = (lat_l0.view(n_, 3, 1) * self.ln_0e).sum(dim=1) + lat_l1o = (lat_l1o.view(n_, 3, 3) * self.ln_1o).sum(dim=1) + lat_l2e = (lat_l2e.view(n_, 3, 5) * self.ln_2e).sum(dim=1) + lat_vec = torch.cat((lat_l0, lat_l1o, lat_l2e), dim=-1) + node_v2 = self.v2_tp(node_feature, lat_vec, self.v2_fc(lat_len.view(n_, -1))) + node_v2 = self.softplus(self.bn(node_v2)) + node_v2 += self.skip_linear(skip_connect) + + return node_v2 + + + +class MatformerConvEqui(nn.Module): + def __init__( + self, + in_channels: Union[int, Tuple[int, int]], + out_channels: int, + edge_dim: Optional[int] = None, + use_second_order_repr: bool = True, + ns: int = 128, + nv: int = 8, + residual: bool = True, + ): + super().__init__() + + if use_second_order_repr: + irrep_seq = [ + f'{ns}x0e', + f'{ns}x0e + {nv}x1o + {nv}x2e', + f'{out_channels}x0e' + ] + else: + irrep_seq = [ + f'{ns}x0e', + f'{ns}x0e + {nv}x1o', + f'{out_channels}x0e', + ] + self.ns = ns + self.nv = nv + # for input x mapping + self.node_linear = nn.Linear(in_channels, ns) + # for input x mapping to the output + self.skip_linear = nn.Linear(in_channels, out_channels) + # for l0 mapping to the output + self.v1_v2_linear = nn.Linear(ns, out_channels) + + self.sh = '1x0e + 1x1o + 1x2e' + self.v2_tp = v2_tp = o3.FullyConnectedTensorProduct(f'{ns}x0e + {nv}x1o + {nv}x2e', '1x0e + 1x1o + 1x2e', f'{ns}x0e + {nv}x1o + {nv}x2e', shared_weights=False) + self.v2_fc = nn.Sequential( + nn.Linear(ns, ns), + nn.Softplus(), + nn.Linear(ns, v2_tp.weight_numel) + ) + + self.v2_tp_2 = v2_tp_2 = o3.FullyConnectedTensorProduct(f'{ns}x0e + {nv}x1o + {nv}x2e', '1x0e + 1x1o + 1x2e', f'{out_channels}x0e', shared_weights=False) + self.v2_fc_2 = nn.Sequential( + nn.Linear(ns, ns), + nn.Softplus(), + nn.Linear(ns, v2_tp_2.weight_numel) + ) + + self.nlayer_1 = TensorProductConvLayer( + in_irreps=irrep_seq[0], + sh_irreps=self.sh, + out_irreps=irrep_seq[1], + n_edge_features=edge_dim, + residual=residual + ) + + self.nlayer_2 = TensorProductConvLayer( + in_irreps=irrep_seq[1], + sh_irreps=self.sh, + out_irreps=irrep_seq[1], + n_edge_features=edge_dim, + residual=residual + ) + + # MACE + self.softplus = nn.Softplus() + self.ln_0e = nn.Parameter(torch.ones(1, ns)) + self.ln_1o = nn.Parameter(torch.ones(1, nv, 1)) + self.ln_2e = nn.Parameter(torch.ones(1, nv, 1)) + self.bn = nn.BatchNorm1d(ns) + + self.ln_0e2 = nn.Parameter(torch.ones(1, ns)) + self.ln_1o2 = nn.Parameter(torch.ones(1, nv, 1)) + self.ln_2e2 = nn.Parameter(torch.ones(1, nv, 1)) + self.bn2 = nn.BatchNorm1d(out_channels) + + def forward(self, data, node_feature: Union[Tensor, PairTensor], edge_index: Adj, edge_feature: Union[Tensor, PairTensor], + edge_nei_len: OptTensor = None): + edge_vec = data.edge_attr + edge_irr = o3.spherical_harmonics(self.sh, edge_vec, normalize=True, normalization='component') + n_ = node_feature.shape[0] + skip_connect = node_feature + node_feature = self.node_linear(node_feature) + ns, nv = self.ns, self.nv + irreps = o3.Irreps(f'{ns}x0e + {nv}x1o + {nv}x2e') + # the first layer + node_feature = self.nlayer_1(node_feature, edge_index, edge_feature, edge_irr) + # the second layer + node_l0, node_l1o, node_l2e = node_feature[:, irreps.slices()[0]], node_feature[:, irreps.slices()[1]], node_feature[:, irreps.slices()[2]] + node_l1o, node_l2e = node_l1o.reshape(n_, -1, 3), node_l2e.reshape(n_, -1, 5) + node_l0 = self.softplus(node_l0) + node_l0_update = (node_l0 * self.ln_0e).sum(dim=1, keepdim=True) + node_l1o = 
(node_l1o * self.ln_1o).sum(dim=1, keepdim=True) + node_l2e = (node_l2e * self.ln_2e).sum(dim=1, keepdim=True) + node_feature_vec = torch.cat((node_l0_update, node_l1o.reshape(n_, -1), node_l2e.reshape(n_, -1)), dim=-1) + node_v2 = self.v2_tp(node_feature, node_feature_vec, self.v2_fc(node_l0)) + node_v2_l0 = node_v2[:, irreps.slices()[0]] + node_v2_l0 = node_v2_l0 + node_l0 + node_v2_l0 = self.softplus(self.bn(node_v2_l0)) + node_v2[:, irreps.slices()[0]] = node_v2_l0 + # the first layer + node_feature = self.nlayer_2(node_v2, edge_index, edge_feature, edge_irr) + # the second layer + node_l0, node_l1o, node_l2e = node_feature[:, irreps.slices()[0]], node_feature[:, irreps.slices()[1]], node_feature[:, irreps.slices()[2]] + node_l1o, node_l2e = node_l1o.reshape(n_, -1, 3), node_l2e.reshape(n_, -1, 5) + node_l0 = self.softplus(node_l0) + node_l0_update = (node_l0 * self.ln_0e2).sum(dim=1, keepdim=True) + node_l1o = (node_l1o * self.ln_1o2).sum(dim=1, keepdim=True) + node_l2e = (node_l2e * self.ln_2e2).sum(dim=1, keepdim=True) + node_feature_vec = torch.cat((node_l0_update, node_l1o.reshape(n_, -1), node_l2e.reshape(n_, -1)), dim=-1) + node_v2 = self.v2_tp_2(node_feature, node_feature_vec, self.v2_fc_2(node_l0)) + node_v2 = node_v2 + self.v1_v2_linear(node_l0) + node_v2 = self.softplus(self.bn2(node_v2)) + + node_v2 += self.skip_linear(skip_connect) + return node_v2 + + + + + # edge_nei_vec = data.edge_nei / data.edge_nei.norm(dim=-1, keepdim=True) + # edge_irr = torch.cat((self.edge_tp(edge_vec.unsqueeze(1).repeat(1, 3, 1), edge_nei_vec, self.edge_tp_fc(edge_nei_len)).sum(dim=1), + # edge_vec), dim=-1) + +# nonlinearity and norm of equi features + # ns, nv = self.ns, self.nv + # irreps = o3.Irreps(f'{ns}x0e + {nv}x1o + {nv}x2e') + # node_l0, node_l1o, node_l1e = node_feature[:, irreps.slices()[0]], node_feature[:, irreps.slices()[1]], node_feature[:, irreps.slices()[2]] + # node_l1o, node_l1e = node_l1o.reshape(n_, -1, 3), node_l1e.reshape(n_, -1, 3) + # # for order = 0 part + # node_l0 = self.softplus(node_l0) + # rms_l0 = node_l0.norm(dim=-1, keepdim=True) * (ns ** -0.5) + # node_l0 = node_l0 / rms_l0.clamp(min = 1e-12) * self.ln_0e + # # for order = 1o part + # l2norm = node_l1o.norm(dim=-1, keepdim=True) + # rms_l1o = l2norm.norm(dim=-2, keepdim=True) * (nv ** -0.5) + # node_l1o = node_l1o / rms_l1o.clamp(min = 1e-12) * self.ln_1o + # # for order = 1e part + # l2norme = node_l1e.norm(dim=-1, keepdim=True) + # rms_l1e = l2norme.norm(dim=-2, keepdim=True) * (nv ** -0.5) + # node_l1e = node_l1e / rms_l1e.clamp(min = 1e-12) * self.ln_1e + # node_feature = torch.cat((node_l0, node_l1o.reshape(n_, -1), node_l1e.reshape(n_, -1)), dim=-1) + # the second layer + # node_feature = self.nlayer_2(node_feature, edge_index, edge_feature, edge_irr) + + + +class MatformerConvEqui(nn.Module): + def __init__( + self, + in_channels: Union[int, Tuple[int, int]], + out_channels: int, + edge_dim: Optional[int] = None, + use_second_order_repr: bool = True, + ns: int = 128, + nv: int = 8, + residual: bool = True, + ): + super().__init__() + + # if use_second_order_repr: + irrep_seq = [ + f'{ns}x0e', + f'{ns}x0e + {nv}x1o + {nv}x2e', + f'{out_channels}x0e' + ] + # else: + # irrep_seq = [ + # f'{ns}x0e', + # f'{ns}x0e + {nv}x1o', + # f'{out_channels}x0e', + # ] + + self.node_linear = nn.Linear(in_channels, ns) + # for input x mapping to the output + self.skip_linear = nn.Linear(in_channels, out_channels) + + self.sh = '1x0e + 1x1o + 1x2e' + + self.nlayer_1 = TensorProductConvLayer( + in_irreps=irrep_seq[0], + 
sh_irreps=self.sh, + out_irreps=irrep_seq[1], + n_edge_features=edge_dim, + residual=residual + ) + + self.nlayer_2 = TensorProductConvLayer( + in_irreps=irrep_seq[1], + sh_irreps=self.sh, + out_irreps=irrep_seq[2], + n_edge_features=edge_dim, + residual=False + ) + + self.softplus = nn.Softplus() + self.bn = nn.BatchNorm1d(out_channels) + + def forward(self, data, node_feature: Union[Tensor, PairTensor], edge_index: Adj, edge_feature: Union[Tensor, PairTensor], + edge_nei_len: OptTensor = None): + edge_vec = data.edge_attr + edge_irr = o3.spherical_harmonics(self.sh, edge_vec, normalize=True, normalization='component') + n_ = node_feature.shape[0] + skip_connect = node_feature + node_feature = self.node_linear(node_feature) + node_feature = self.nlayer_1(node_feature, edge_index, edge_feature, edge_irr) + node_feature = self.nlayer_2(node_feature, edge_index, edge_feature, edge_irr) + node_feature = self.softplus(self.bn(node_feature)) + node_feature += self.skip_linear(skip_connect) + + return node_feature + + + + + # ns, nv = self.ns, self.nv + # irreps = o3.Irreps(f'{ns}x0e + {nv}x1o + {nv}x2e') + # node_l0, node_l1o, node_l2e = node_feature[:, irreps.slices()[0]], node_feature[:, irreps.slices()[1]], node_feature[:, irreps.slices()[2]] + # node_l1o, node_l2e = node_l1o.reshape(n_, -1, 3), node_l2e.reshape(n_, -1, 5) + # # for order = 0 part + # node_l0 = self.softplus(node_l0) + # rms_l0 = node_l0.norm(dim=-1, keepdim=True) * (ns ** -0.5) + # node_l0 = node_l0 / rms_l0.clamp(min = 1e-12) * self.ln_0e + # # for order = 1o part + # l2norm = node_l1o.norm(dim=-1, keepdim=True) + # rms_l1o = l2norm.norm(dim=-2, keepdim=True) * (nv ** -0.5) + # node_l1o = node_l1o / rms_l1o.clamp(min = 1e-12) * self.ln_1o + # # for order = 1e part + # l2norme = node_l2e.norm(dim=-1, keepdim=True) + # rms_l2e = l2norme.norm(dim=-2, keepdim=True) * (nv ** -0.5) + # node_l2e = node_l2e / rms_l2e.clamp(min = 1e-12) * self.ln_2e + # node_feature = torch.cat((node_l0, node_l1o.reshape(n_, -1), node_l2e.reshape(n_, -1)), dim=-1) \ No newline at end of file diff --git a/benchmarks/matbench_v0.1_iComFormer/models/bn_utils.py b/benchmarks/matbench_v0.1_iComFormer/models/bn_utils.py new file mode 100644 index 00000000..eec8f27e --- /dev/null +++ b/benchmarks/matbench_v0.1_iComFormer/models/bn_utils.py @@ -0,0 +1,269 @@ +from typing import Optional, Any + +import torch +from torch import Tensor +from torch.nn.parameter import Parameter, UninitializedParameter, UninitializedBuffer + +from torch.nn import functional as F +from torch.nn import init +from torch.nn.modules._functions import SyncBatchNorm as sync_batch_norm +from torch.nn.modules.lazy import LazyModuleMixin +from torch.nn.modules.module import Module + + +class _NormBase(Module): + """Common base of _InstanceNorm and _BatchNorm""" + + _version = 2 + __constants__ = ["track_running_stats", "momentum", "eps", "num_features", "affine"] + num_features: int + eps: float + momentum: float + affine: bool + track_running_stats: bool + # WARNING: weight and bias purposely not defined here. 
+ # See https://github.com/pytorch/pytorch/issues/39670 + + def __init__( + self, + num_features: int, + eps: float = 1e-5, + momentum: float = 0.1, + affine: bool = True, + track_running_stats: bool = True, + device=None, + dtype=None + ) -> None: + factory_kwargs = {'device': device, 'dtype': dtype} + super(_NormBase, self).__init__() + self.num_features = num_features + self.eps = eps + self.momentum = momentum + self.affine = affine + self.track_running_stats = track_running_stats + if self.affine: + self.weight = Parameter(torch.empty(num_features, **factory_kwargs)) + self.bias = Parameter(torch.empty(num_features, **factory_kwargs)) + else: + self.register_parameter("weight", None) + self.register_parameter("bias", None) + if self.track_running_stats: + self.register_buffer('running_mean', torch.zeros(num_features, **factory_kwargs)) + self.register_buffer('running_var', torch.ones(num_features, **factory_kwargs)) + self.running_mean: Optional[Tensor] + self.running_var: Optional[Tensor] + self.register_buffer('num_batches_tracked', + torch.tensor(0, dtype=torch.long, + **{k: v for k, v in factory_kwargs.items() if k != 'dtype'})) + self.num_batches_tracked: Optional[Tensor] + else: + self.register_buffer("running_mean", None) + self.register_buffer("running_var", None) + self.register_buffer("num_batches_tracked", None) + self.reset_parameters() + + def reset_running_stats(self) -> None: + if self.track_running_stats: + # running_mean/running_var/num_batches... are registered at runtime depending + # if self.track_running_stats is on + self.running_mean.zero_() # type: ignore[union-attr] + self.running_var.fill_(1) # type: ignore[union-attr] + self.num_batches_tracked.zero_() # type: ignore[union-attr,operator] + + def reset_parameters(self) -> None: + self.reset_running_stats() + if self.affine: + init.ones_(self.weight) + init.zeros_(self.bias) + + def _check_input_dim(self, input): + raise NotImplementedError + + def extra_repr(self): + return ( + "{num_features}, eps={eps}, momentum={momentum}, affine={affine}, " + "track_running_stats={track_running_stats}".format(**self.__dict__) + ) + + def _load_from_state_dict( + self, + state_dict, + prefix, + local_metadata, + strict, + missing_keys, + unexpected_keys, + error_msgs, + ): + version = local_metadata.get("version", None) + + if (version is None or version < 2) and self.track_running_stats: + # at version 2: added num_batches_tracked buffer + # this should have a default value of 0 + num_batches_tracked_key = prefix + "num_batches_tracked" + if num_batches_tracked_key not in state_dict: + state_dict[num_batches_tracked_key] = torch.tensor(0, dtype=torch.long) + + super(_NormBase, self)._load_from_state_dict( + state_dict, + prefix, + local_metadata, + strict, + missing_keys, + unexpected_keys, + error_msgs, + ) + + +class _MaskedBatchNorm(_NormBase): + def __init__( + self, + num_features, + eps=1e-5, + momentum=0.1, + affine=True, + track_running_stats=True, + device=None, + dtype=None + ): + factory_kwargs = {'device': device, 'dtype': dtype} + super(_MaskedBatchNorm, self).__init__( + num_features, eps, momentum, affine, track_running_stats, **factory_kwargs + ) + + def forward(self, input: Tensor, mask: Tensor) -> Tensor: + self._check_input_dim(input) + # exponential_average_factor is set to self.momentum + # (when it is available) only so that it gets updated + # in ONNX graph when this node is exported to ONNX. 
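+ # During training, the batch statistics used below are computed only over entries where
+ # mask == 1 (see `batch_norm` at the bottom of this file):
+ #   count = mask.sum(dim=0); mean = (X * mask).sum(dim=0) / (count + 1e-5)
+ #   var = (((X - mean) ** 2) * mask).sum(dim=0) / (count + 1e-5)
+ # In eval mode the running statistics are used as usual, and in both modes the scaled
+ # output is multiplied by the mask again so that masked-out rows stay zero.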
+ if self.momentum is None: + exponential_average_factor = 0.0 + else: + exponential_average_factor = self.momentum + + if self.training and self.track_running_stats: + # TODO: if statement only here to tell the jit to skip emitting this when it is None + if self.num_batches_tracked is not None: # type: ignore[has-type] + self.num_batches_tracked.add_(1) # type: ignore[has-type] + if self.momentum is None: # use cumulative moving average + exponential_average_factor = 1.0 / float(self.num_batches_tracked) + else: # use exponential moving average + exponential_average_factor = self.momentum + + r""" + Decide whether the mini-batch stats should be used for normalization rather than the buffers. + Mini-batch stats are used in training mode, and in eval mode when buffers are None. + """ + if self.training: + bn_training = True + else: + bn_training = (self.running_mean is None) and (self.running_var is None) + + r""" + Buffers are only updated if they are to be tracked and we are in training mode. Thus they only need to be + passed when the update should occur (i.e. in training mode when they are tracked), or when buffer stats are + used for normalization (i.e. in eval mode when buffers are not None). + """ + result, self.running_mean, self.running_var = batch_norm( + X=input, + running_mean=self.running_mean + if not self.training or self.track_running_stats + else None, + running_var=self.running_var if not self.training or self.track_running_stats else None, + weight=self.weight, + bias=self.bias, + training=bn_training, + momentum=exponential_average_factor, + eps=self.eps, + mask=mask, + ) + return result + +def batch_norm(X, weight, bias, running_mean, running_var, training, momentum, eps, mask): + if not training: + X_hat = (X - running_mean) / torch.sqrt(running_var + eps) + else: + count = mask.sum(dim=0, keepdim=True) + mean = (X * mask).sum(dim=0, keepdim=True) / (count + 1e-5) + var = (((X - mean) ** 2) * mask).sum(dim=0, keepdim=True) / (count + 1e-5) + X_hat = (X - mean) / torch.sqrt(var + eps) + # Update the mean and variance using moving average + running_mean = momentum * running_mean + (1.0 - momentum) * mean + running_var = momentum * running_var + (1.0 - momentum) * var + Y = (weight * X_hat + bias) * mask # Scale and shift + return Y, running_mean.data, running_var.data + + + +class MaskedBatchNorm1d(_MaskedBatchNorm): + r"""Applies Batch Normalization over a 2D or 3D input as described in the paper + `Batch Normalization: Accelerating Deep Network Training by Reducing + Internal Covariate Shift `__ . + + .. math:: + + y = \frac{x - \mathrm{E}[x]}{\sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta + + The mean and standard-deviation are calculated per-dimension over + the mini-batches and :math:`\gamma` and :math:`\beta` are learnable parameter vectors + of size `C` (where `C` is the number of features or channels of the input). By default, the + elements of :math:`\gamma` are set to 1 and the elements of :math:`\beta` are set to 0. The + standard-deviation is calculated via the biased estimator, equivalent to `torch.var(input, unbiased=False)`. + + Also by default, during training this layer keeps running estimates of its + computed mean and variance, which are then used for normalization during + evaluation. The running estimates are kept with a default :attr:`momentum` + of 0.1. + + If :attr:`track_running_stats` is set to ``False``, this layer then does not + keep running estimates, and batch statistics are instead used during + evaluation time as well. + + .. 
note:: + This :attr:`momentum` argument is different from one used in optimizer + classes and the conventional notion of momentum. Mathematically, the + update rule for running statistics here is + :math:`\hat{x}_\text{new} = (1 - \text{momentum}) \times \hat{x} + \text{momentum} \times x_t`, + where :math:`\hat{x}` is the estimated statistic and :math:`x_t` is the + new observed value. + + Because the Batch Normalization is done over the `C` dimension, computing statistics + on `(N, L)` slices, it's common terminology to call this Temporal Batch Normalization. + + Args: + num_features: number of features or channels :math:`C` of the input + eps: a value added to the denominator for numerical stability. + Default: 1e-5 + momentum: the value used for the running_mean and running_var + computation. Can be set to ``None`` for cumulative moving average + (i.e. simple average). Default: 0.1 + affine: a boolean value that when set to ``True``, this module has + learnable affine parameters. Default: ``True`` + track_running_stats: a boolean value that when set to ``True``, this + module tracks the running mean and variance, and when set to ``False``, + this module does not track such statistics, and initializes statistics + buffers :attr:`running_mean` and :attr:`running_var` as ``None``. + When these buffers are ``None``, this module always uses batch statistics. + in both training and eval modes. Default: ``True`` + + Shape: + - Input: :math:`(N, C)` or :math:`(N, C, L)`, where :math:`N` is the batch size, + :math:`C` is the number of features or channels, and :math:`L` is the sequence length + - Output: :math:`(N, C)` or :math:`(N, C, L)` (same shape as input) + + Examples:: + + >>> # With Learnable Parameters + >>> m = nn.BatchNorm1d(100) + >>> # Without Learnable Parameters + >>> m = nn.BatchNorm1d(100, affine=False) + >>> input = torch.randn(20, 100) + >>> output = m(input) + """ + + def _check_input_dim(self, input): + if input.dim() != 2 and input.dim() != 3: + raise ValueError( + "expected 2D or 3D input (got {}D input)".format(input.dim()) + ) + diff --git a/benchmarks/matbench_v0.1_iComFormer/models/pyg_att.py b/benchmarks/matbench_v0.1_iComFormer/models/pyg_att.py new file mode 100644 index 00000000..0ac69300 --- /dev/null +++ b/benchmarks/matbench_v0.1_iComFormer/models/pyg_att.py @@ -0,0 +1,263 @@ +"""Implementation based on the template of ALIGNN.""" + +from typing import Tuple +import math +import numpy as np +import torch +import torch.nn.functional as F +from pydantic.typing import Literal +from torch import nn +from matformer.models.utils import RBFExpansion +from matformer.utils import BaseSettings +from matformer.features import angle_emb_mp +from torch_scatter import scatter +from matformer.models.transformer import MatformerConv, MatformerConv_edge, MatformerConvEqui + + +class MatformerConfig(BaseSettings): + """Hyperparameter schema for jarvisdgl.models.cgcnn.""" + + name: Literal["matformer"] + conv_layers: int = 3 + edge_layers: int = 1 + atom_input_features: int = 92 + edge_features: int = 256 + triplet_input_features: int = 256 + node_features: int = 256 + fc_layers: int = 1 + fc_features: int = 256 + output_features: int = 1 + node_layer_head: int = 1 + edge_layer_head: int = 1 + nn_based: bool = False + + link: Literal["identity", "log", "logit"] = "identity" + zero_inflated: bool = False + use_angle: bool = False + angle_lattice: bool = False + classification: bool = False + + class Config: + """Configure model settings behavior.""" + + env_prefix = "jv_model" + + 
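+# Illustrative sketch (not used anywhere in this benchmark): how MatformerConfig
+# is typically instantiated. The field values are placeholders rather than tuned
+# hyperparameters, and the environment-variable override mentioned below assumes
+# that matformer.utils.BaseSettings follows pydantic's BaseSettings semantics for
+# `env_prefix = "jv_model"` (e.g. a JV_MODEL_CONV_LAYERS variable); that behaviour
+# is an assumption, not something verified here.
+def _example_matformer_config() -> "MatformerConfig":
+    """Build a config with a few explicitly overridden hyperparameters."""
+    cfg = MatformerConfig(
+        name="matformer",   # the only accepted value for `name`
+        conv_layers=4,      # number of MatformerConv node-attention layers
+        node_features=128,  # shared hidden width for node and edge embeddings
+    )
+    # Fields not passed explicitly keep the defaults declared on the class above;
+    # cfg.dict() returns the plain-dict view of all fields.
+    return cfg
+
+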
+def bond_cosine(r1, r2):
+    bond_cosine = torch.sum(r1 * r2, dim=-1) / (
+        torch.norm(r1, dim=-1) * torch.norm(r2, dim=-1)
+    )
+    bond_cosine = torch.clamp(bond_cosine, -1, 1)
+    return bond_cosine
+
+class MatformerEquivariant(nn.Module):
+    """att pyg implementation."""
+
+    def __init__(self, config: MatformerConfig = MatformerConfig(name="matformer")):
+        """Set up att modules."""
+        super().__init__()
+        print("Using equivariant matformer")
+        self.classification = config.classification
+        self.use_angle = config.use_angle
+        self.atom_embedding = nn.Linear(
+            config.atom_input_features, config.node_features
+        )
+        self.rbf = nn.Sequential(
+            RBFExpansion(
+                vmin=-4.0,
+                vmax=0.0,
+                bins=config.edge_features,
+            ),
+            nn.Linear(config.edge_features, config.node_features),
+            nn.Softplus(),
+            # nn.Linear(config.node_features, config.node_features),
+        )
+
+        self.rbf_angle = nn.Sequential(
+            RBFExpansion(
+                vmin=-1.0,
+                vmax=1.0,
+                bins=config.triplet_input_features,
+            ),
+            nn.Linear(config.triplet_input_features, config.node_features),
+            nn.Softplus(),
+            # nn.Linear(config.node_features, config.node_features),
+        )
+
+        self.att_layers = nn.ModuleList(
+            [
+                MatformerConv(in_channels=config.node_features, out_channels=config.node_features, heads=config.node_layer_head, edge_dim=config.node_features)
+                for _ in range(config.conv_layers)
+            ]
+        )
+
+        self.edge_update_layer = MatformerConv_edge(in_channels=config.node_features, out_channels=config.node_features, heads=config.node_layer_head, edge_dim=config.node_features)
+
+        self.equi_update = MatformerConvEqui(in_channels=config.node_features, out_channels=config.node_features, edge_dim=config.node_features, use_second_order_repr=True)
+
+        self.fc = nn.Sequential(
+            nn.Linear(config.node_features, config.fc_features), nn.SiLU()
+        )
+        self.sigmoid = nn.Sigmoid()
+
+        if self.classification:
+            self.fc_out = nn.Linear(config.fc_features, 2)
+            self.softmax = nn.LogSoftmax(dim=1)
+        else:
+            self.fc_out = nn.Linear(
+                config.fc_features, config.output_features
+            )
+
+        self.link = None
+        self.link_name = config.link
+        if config.link == "identity":
+            self.link = lambda x: x
+        elif config.link == "log":
+            self.link = torch.exp
+            avg_gap = 0.7  # magic number -- average bandgap in dft_3d
+            if not config.zero_inflated:
+                self.fc_out.bias.data = torch.tensor(
+                    np.log(avg_gap), dtype=torch.float
+                )
+        elif config.link == "logit":
+            self.link = torch.sigmoid
+
+    def forward(self, data) -> torch.Tensor:
+        data, ldata, lattice = data
+        node_features = self.atom_embedding(data.x)
+        n_nodes = node_features.shape[0]
+        edge_feat = -0.75 / torch.norm(data.edge_attr, dim=1)
+        # lat_feat = -0.75 / torch.norm(data.atom_lat.view(n_nodes * 3, 3), dim=1)
+        # edge_nei_len = -0.75 / torch.norm(data.edge_nei, dim=-1) # [num_edges, 3]
+        # edge_nei_angle = bond_cosine(data.edge_nei, data.edge_attr.unsqueeze(1).repeat(1, 3, 1)) # [num_edges, 3, 3] -> [num_edges, 3]
+        num_edge = edge_feat.shape[0]
+        edge_features = self.rbf(edge_feat)
+        # lat_features = self.rbf(lat_feat).view(n_nodes, 3, -1)
+        # edge_nei_len = self.rbf(edge_nei_len.view(-1)).view(num_edge, 3, -1)
+        # edge_nei_angle = self.rbf_angle(edge_nei_angle.view(-1)).view(num_edge, 3, -1)
+
+        node_features = self.att_layers[0](node_features, data.edge_index, edge_features) # / math.sqrt(16)
+        # edge_features = self.edge_update_layer(edge_features, edge_nei_len, edge_nei_angle) # / math.sqrt(4)
+        # node_features = self.att_layers[1](node_features, data.edge_index, edge_features) # / math.sqrt(16)
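+        # Remaining steps (below): one equivariant tensor-product update
+        # (MatformerConvEqui) built from l = 0/1/2 spherical harmonics of the edge
+        # vectors, one further attention layer, then a mean readout over the atoms
+        # of each crystal (scatter over data.batch) feeding the FC output head.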
node_features = self.equi_update(data, node_features, data.edge_index, edge_features) + node_features = self.att_layers[2](node_features, data.edge_index, edge_features) # / math.sqrt(16) + + # crystal-level readout + features = scatter(node_features, data.batch, dim=0, reduce="mean") + + # features = F.softplus(features) + features = self.fc(features) + + out = self.fc_out(features) + if self.link: + out = self.link(out) + if self.classification: + out = self.softmax(out) + + return torch.squeeze(out) + + + + +class MatformerInvariant(nn.Module): + """att pyg implementation.""" + + def __init__(self, config: MatformerConfig = MatformerConfig(name="matformer")): + """Set up att modules.""" + super().__init__() + print("Using invariant marformer !!!!!!!!!!!!!!!!!!!!!!!!") + self.classification = config.classification + self.use_angle = config.use_angle + self.atom_embedding = nn.Linear( + config.atom_input_features, config.node_features + ) + self.rbf = nn.Sequential( + RBFExpansion( + vmin=-4.0, + vmax=0.0, + bins=config.edge_features, + ), + nn.Linear(config.edge_features, config.node_features), + nn.Softplus(), + ) + + self.rbf_angle = nn.Sequential( + RBFExpansion( + vmin=-1.0, + vmax=1.0, + bins=config.triplet_input_features, + ), + nn.Linear(config.triplet_input_features, config.node_features), + nn.Softplus(), + ) + + self.att_layers = nn.ModuleList( + [ + MatformerConv(in_channels=config.node_features, out_channels=config.node_features, heads=config.node_layer_head, edge_dim=config.node_features) + for _ in range(config.conv_layers) + ] + ) + + self.edge_update_layer = MatformerConv_edge(in_channels=config.node_features, out_channels=config.node_features, heads=config.node_layer_head, edge_dim=config.node_features) + + self.fc = nn.Sequential( + nn.Linear(config.node_features, config.fc_features), nn.SiLU() + ) + self.sigmoid = nn.Sigmoid() + + if self.classification: + self.fc_out = nn.Linear(config.fc_features, 2) + self.softmax = nn.LogSoftmax(dim=1) + else: + self.fc_out = nn.Linear( + config.fc_features, config.output_features + ) + + self.link = None + self.link_name = config.link + if config.link == "identity": + self.link = lambda x: x + elif config.link == "log": + self.link = torch.exp + avg_gap = 0.7 # magic number -- average bandgap in dft_3d + if not self.zero_inflated: + self.fc_out.bias.data = torch.tensor( + np.log(avg_gap), dtype=torch.float + ) + elif config.link == "logit": + self.link = torch.sigmoid + + def forward(self, data) -> torch.Tensor: + data, ldata, lattice = data + node_features = self.atom_embedding(data.x) + edge_feat = -0.75 / torch.norm(data.edge_attr, dim=1) # [num_edges] + edge_nei_len = -0.75 / torch.norm(data.edge_nei, dim=-1) # [num_edges, 3] + edge_nei_angle = bond_cosine(data.edge_nei, data.edge_attr.unsqueeze(1).repeat(1, 3, 1)) # [num_edges, 3, 3] -> [num_edges, 3] + num_edge = edge_feat.shape[0] + edge_features = self.rbf(edge_feat) + edge_nei_len = self.rbf(edge_nei_len.reshape(-1)).reshape(num_edge, 3, -1) + edge_nei_angle = self.rbf_angle(edge_nei_angle.reshape(-1)).reshape(num_edge, 3, -1) + + node_features = self.att_layers[0](node_features, data.edge_index, edge_features) # / math.sqrt(16) + edge_features = self.edge_update_layer(edge_features, edge_nei_len, edge_nei_angle) # / math.sqrt(4) + node_features = self.att_layers[1](node_features, data.edge_index, edge_features) # / math.sqrt(16) + # edge_features = self.edge_update_layers[1](edge_features, ldata.edge_index, angle_features) + node_features = 
self.att_layers[2](node_features, data.edge_index, edge_features) # / math.sqrt(16) + # node_features = self.att_layers[3](node_features, data.edge_index, edge_features) # / math.sqrt(16) + # node_features = self.att_layers[4](node_features, data.edge_index, edge_features) # / math.sqrt(16) + + # crystal-level readout + features = scatter(node_features, data.batch, dim=0, reduce="mean") + + # features = F.softplus(features) + features = self.fc(features) + + out = self.fc_out(features) + if self.link: + out = self.link(out) + if self.classification: + out = self.softmax(out) + + return torch.squeeze(out) + + diff --git a/benchmarks/matbench_v0.1_iComFormer/models/transformer.py b/benchmarks/matbench_v0.1_iComFormer/models/transformer.py new file mode 100644 index 00000000..cf504556 --- /dev/null +++ b/benchmarks/matbench_v0.1_iComFormer/models/transformer.py @@ -0,0 +1,282 @@ +import math +from e3nn import o3 +from typing import Optional, Tuple, Union + +import torch +import torch.nn.functional as F +from torch import Tensor +from torch_sparse import SparseTensor +import torch.nn as nn + +from torch_geometric.nn.conv import MessagePassing +from torch_geometric.typing import Adj, OptTensor, PairTensor +from matformer.models.utils import softmax +from torch_scatter import scatter + + +class MatformerConv(MessagePassing): + _alpha: OptTensor + + def __init__( + self, + in_channels: Union[int, Tuple[int, int]], + out_channels: int, + heads: int = 1, + concat: bool = True, + beta: bool = False, + dropout: float = 0.0, + edge_dim: Optional[int] = None, + bias: bool = True, + root_weight: bool = True, + **kwargs, + ): + kwargs.setdefault('aggr', 'add') + super(MatformerConv, self).__init__(node_dim=0, **kwargs) + + self.in_channels = in_channels + self.out_channels = out_channels + self.heads = heads + self.beta = beta and root_weight + self.root_weight = root_weight + self.concat = concat + self.dropout = dropout + self.edge_dim = edge_dim + self._alpha = None + + if isinstance(in_channels, int): + in_channels = (in_channels, in_channels) + + self.lin_key = nn.Linear(in_channels[0], heads * out_channels) + self.lin_query = nn.Linear(in_channels[1], heads * out_channels) + self.lin_value = nn.Linear(in_channels[0], heads * out_channels) + self.lin_edge = nn.Linear(edge_dim, heads * out_channels) + self.lin_concate = nn.Linear(heads * out_channels, out_channels) + + self.lin_msg_update = nn.Sequential(nn.Linear(out_channels * 3, out_channels), + nn.SiLU(), + nn.Linear(out_channels, out_channels)) + self.softplus = nn.Softplus() + self.silu = nn.SiLU() + self.key_update = nn.Sequential(nn.Linear(out_channels * 3, out_channels), + nn.SiLU(), + nn.Linear(out_channels, out_channels)) + self.bn = nn.BatchNorm1d(out_channels) + self.bn_att = nn.BatchNorm1d(out_channels) + self.sigmoid = nn.Sigmoid() + print('I am using the correct version of matformer') + + def forward(self, x: Union[Tensor, PairTensor], edge_index: Adj, + edge_attr: OptTensor = None, return_attention_weights=None): + + H, C = self.heads, self.out_channels + if isinstance(x, Tensor): + x: PairTensor = (x, x) + + query = self.lin_query(x[1]).view(-1, H, C) + key = self.lin_key(x[0]).view(-1, H, C) + value = self.lin_value(x[0]).view(-1, H, C) + + out = self.propagate(edge_index, query=query, key=key, value=value, + edge_attr=edge_attr, size=None) + + out = out.view(-1, self.heads * self.out_channels) + out = self.lin_concate(out) + + return self.softplus(x[1] + self.bn(out)) + + def message(self, query_i: Tensor, key_i: Tensor, key_j: 
Tensor, value_j: Tensor, value_i: Tensor, + edge_attr: OptTensor, index: Tensor, ptr: OptTensor, + size_i: Optional[int]) -> Tensor: + + edge_attr = self.lin_edge(edge_attr).view(-1, self.heads, self.out_channels) + key_j = self.key_update(torch.cat((key_i, key_j, edge_attr), dim=-1)) + alpha = (query_i * key_j) / math.sqrt(self.out_channels) + out = self.lin_msg_update(torch.cat((value_i, value_j, edge_attr), dim=-1)) + out = out * self.sigmoid(self.bn_att(alpha.view(-1, self.out_channels)).view(-1, self.heads, self.out_channels)) + return out + + +class MatformerConv_edge(nn.Module): + def __init__( + self, + in_channels: Union[int, Tuple[int, int]], + out_channels: int, + heads: int = 1, + concat: bool = True, + beta: bool = False, + dropout: float = 0.0, + edge_dim: Optional[int] = None, + bias: bool = True, + root_weight: bool = True, + ): + super().__init__() + + self.in_channels = in_channels + self.out_channels = out_channels + self.heads = heads + self.beta = beta and root_weight + self.root_weight = root_weight + self.concat = concat + self.dropout = dropout + self.edge_dim = edge_dim + + if isinstance(in_channels, int): + in_channels = (in_channels, in_channels) + self.lemb = nn.Embedding(num_embeddings=3, embedding_dim=32) + self.embedding_dim = 32 + self.lin_key = nn.Linear(in_channels[0], heads * out_channels) + self.lin_query = nn.Linear(in_channels[1], heads * out_channels) + self.lin_value = nn.Linear(in_channels[0], heads * out_channels) + # for test + self.lin_key_e1 = nn.Linear(in_channels[0], heads * out_channels) + self.lin_value_e1 = nn.Linear(in_channels[0], heads * out_channels) + self.lin_key_e2 = nn.Linear(in_channels[0], heads * out_channels) + self.lin_value_e2 = nn.Linear(in_channels[0], heads * out_channels) + self.lin_key_e3 = nn.Linear(in_channels[0], heads * out_channels) + self.lin_value_e3 = nn.Linear(in_channels[0], heads * out_channels) + # for test ends + self.lin_edge = nn.Linear(edge_dim, heads * out_channels, bias=False) + self.lin_edge_len = nn.Linear(in_channels[0] + self.embedding_dim, in_channels[0]) + self.lin_concate = nn.Linear(heads * out_channels, out_channels) + self.lin_msg_update = nn.Sequential(nn.Linear(out_channels * 3, out_channels), + nn.SiLU(), + nn.Linear(out_channels, out_channels)) + self.silu = nn.SiLU() + self.softplus = nn.Softplus() + self.key_update = nn.Sequential(nn.Linear(out_channels * 3, out_channels), + nn.SiLU(), + nn.Linear(out_channels, out_channels)) + self.bn_att = nn.BatchNorm1d(out_channels) + + self.bn = nn.BatchNorm1d(out_channels) + self.sigmoid = nn.Sigmoid() + print('I am using the invariant version of EPCNet') + + def forward(self, edge: Union[Tensor, PairTensor], edge_nei_len: OptTensor = None, edge_nei_angle: OptTensor = None): + # preprocess for edge of shape [num_edges, hidden_dim] + + H, C = self.heads, self.out_channels + if isinstance(edge, Tensor): + edge: PairTensor = (edge, edge) + device = edge[1].device + query_x = self.lin_query(edge[1]).view(-1, H, C).unsqueeze(1).repeat(1, 3, 1, 1) + key_x = self.lin_key(edge[0]).view(-1, H, C).unsqueeze(1).repeat(1, 3, 1, 1) + value_x = self.lin_value(edge[0]).view(-1, H, C).unsqueeze(1).repeat(1, 3, 1, 1) + num_edge = query_x.shape[0] + # preprocess for edge_neighbor of shape [num_edges, 3, hidden_dim] + # lembs = torch.cat((self.lemb(torch.tensor([0]).to(device)), self.lemb(torch.tensor([1]).to(device)), self.lemb(torch.tensor([2]).to(device))), dim=0).unsqueeze(0).repeat(num_edge, 1, 1) + # edge_nei_len = self.lin_edge_len(torch.cat((edge_nei_len, 
lembs), dim=-1)) + # query_y = self.lin_query(edge_nei_len).view(-1, 3, H, C) + # key_y = self.lin_key(edge_nei_len).view(-1, 3, H, C) + # value_y = self.lin_value(edge_nei_len).view(-1, 3, H, C) + + # test begin + key_y = torch.cat((self.lin_key_e1(edge_nei_len[:,0,:]).view(-1, 1, H, C), + self.lin_key_e2(edge_nei_len[:,1,:]).view(-1, 1, H, C), + self.lin_key_e3(edge_nei_len[:,2,:]).view(-1, 1, H, C)), dim=1) + value_y = torch.cat((self.lin_value_e1(edge_nei_len[:,0,:]).view(-1, 1, H, C), + self.lin_value_e2(edge_nei_len[:,1,:]).view(-1, 1, H, C), + self.lin_value_e3(edge_nei_len[:,2,:]).view(-1, 1, H, C)), dim=1) + # test end + + # preprocess for interaction of shape [num_edges, 3, hidden_dim] + edge_xy = self.lin_edge(edge_nei_angle).view(-1, 3, H, C) + + key = self.key_update(torch.cat((key_x, key_y, edge_xy), dim=-1)) + alpha = (query_x * key) / math.sqrt(self.out_channels) + out = self.lin_msg_update(torch.cat((value_x, value_y, edge_xy), dim=-1)) + out = out * self.sigmoid(self.bn_att(alpha.view(-1, self.out_channels)).view(-1, 3, self.heads, self.out_channels)) + + out = out.view(-1, 3, self.heads * self.out_channels) + out = self.lin_concate(out) + # aggregate the msg + out = out.sum(dim=1) + + return self.softplus(edge[1] + self.bn(out)) + + + +class TensorProductConvLayer(torch.nn.Module): + # from Torsional diffusion + def __init__(self, in_irreps, sh_irreps, out_irreps, n_edge_features, residual=True): + super(TensorProductConvLayer, self).__init__() + self.in_irreps = in_irreps + self.out_irreps = out_irreps + self.sh_irreps = sh_irreps + self.residual = residual + + self.tp = tp = o3.FullyConnectedTensorProduct(in_irreps, sh_irreps, out_irreps, shared_weights=False) + + self.fc = nn.Sequential( + nn.Linear(n_edge_features, n_edge_features), + nn.Softplus(), + nn.Linear(n_edge_features, tp.weight_numel) + ) + + def forward(self, node_attr, edge_index, edge_attr, edge_sh, out_nodes=None, reduce='mean'): + + edge_src, edge_dst = edge_index + tp = self.tp(node_attr[edge_dst], edge_sh, self.fc(edge_attr)) + + out_nodes = out_nodes or node_attr.shape[0] + out = scatter(tp, edge_src, dim=0, dim_size=out_nodes, reduce=reduce) + if self.residual: + padded = F.pad(node_attr, (0, out.shape[-1] - node_attr.shape[-1])) + out = out + padded + + return out + + +class MatformerConvEqui(nn.Module): + def __init__( + self, + in_channels: Union[int, Tuple[int, int]], + out_channels: int, + edge_dim: Optional[int] = None, + use_second_order_repr: bool = True, + ns: int = 128, + nv: int = 8, + residual: bool = True, + ): + super().__init__() + + irrep_seq = [ + f'{ns}x0e', + f'{ns}x0e + {nv}x1o + {nv}x2e', + f'{ns}x0e' + ] + self.ns, self.nv = ns, nv + self.node_linear = nn.Linear(in_channels, ns) + self.skip_linear = nn.Linear(in_channels, out_channels) + self.sh = '1x0e + 1x1o + 1x2e' + self.nlayer_1 = TensorProductConvLayer( + in_irreps=irrep_seq[0], + sh_irreps=self.sh, + out_irreps=irrep_seq[1], + n_edge_features=edge_dim, + residual=residual + ) + self.nlayer_2 = TensorProductConvLayer( + in_irreps=irrep_seq[1], + sh_irreps=self.sh, + out_irreps=irrep_seq[2], + n_edge_features=edge_dim, + residual=False + ) + self.softplus = nn.Softplus() + self.bn = nn.BatchNorm1d(ns) + self.node_linear_2 = nn.Linear(ns, out_channels) + + def forward(self, data, node_feature: Union[Tensor, PairTensor], edge_index: Adj, edge_feature: Union[Tensor, PairTensor], + edge_nei_len: OptTensor = None): + edge_vec = data.edge_attr + edge_irr = o3.spherical_harmonics(self.sh, edge_vec, normalize=True, 
normalization='component') + n_ = node_feature.shape[0] + skip_connect = node_feature + node_feature = self.node_linear(node_feature) + node_feature = self.nlayer_1(node_feature, edge_index, edge_feature, edge_irr) + node_feature = self.nlayer_2(node_feature, edge_index, edge_feature, edge_irr) + node_feature = self.softplus(self.node_linear_2(self.softplus(self.bn(node_feature)))) + node_feature += self.skip_linear(skip_connect) + + return node_feature \ No newline at end of file diff --git a/benchmarks/matbench_v0.1_iComFormer/models/utils.py b/benchmarks/matbench_v0.1_iComFormer/models/utils.py new file mode 100644 index 00000000..aa01ef1b --- /dev/null +++ b/benchmarks/matbench_v0.1_iComFormer/models/utils.py @@ -0,0 +1,126 @@ +"""Shared model-building components.""" +from typing import Optional + +import numpy as np +import torch +from torch import nn + +from torch import Tensor +from torch_scatter import gather_csr, scatter, segment_csr + +from torch_geometric.utils.num_nodes import maybe_num_nodes + +class RBFExpansion(nn.Module): + """Expand interatomic distances with radial basis functions.""" + + def __init__( + self, + vmin: float = 0, + vmax: float = 8, + bins: int = 40, + lengthscale: Optional[float] = None, + ): + """Register torch parameters for RBF expansion.""" + super().__init__() + self.vmin = vmin + self.vmax = vmax + self.bins = bins + self.register_buffer( + "centers", torch.linspace(self.vmin, self.vmax, self.bins) + ) + + if lengthscale is None: + # SchNet-style + # set lengthscales relative to granularity of RBF expansion + self.lengthscale = np.diff(self.centers).mean() + self.gamma = 1 / self.lengthscale + + else: + self.lengthscale = lengthscale + self.gamma = 1 / (lengthscale ** 2) + + def forward(self, distance: torch.Tensor) -> torch.Tensor: + """Apply RBF expansion to interatomic distance tensor.""" + return torch.exp( + -self.gamma * (distance.unsqueeze(1) - self.centers) ** 2 + ) + + +@torch.jit.script +def softmax(src: Tensor, index: Optional[Tensor] = None, + ptr: Optional[Tensor] = None, num_nodes: Optional[int] = None, + dim: int = 0) -> Tensor: + r"""Computes a sparsely evaluated softmax. + Given a value tensor :attr:`src`, this function first groups the values + along the first dimension based on the indices specified in :attr:`index`, + and then proceeds to compute the softmax individually for each group. + Args: + src (Tensor): The source tensor. + index (LongTensor, optional): The indices of elements for applying the + softmax. (default: :obj:`None`) + ptr (LongTensor, optional): If given, computes the softmax based on + sorted inputs in CSR representation. (default: :obj:`None`) + num_nodes (int, optional): The number of nodes, *i.e.* + :obj:`max_val + 1` of :attr:`index`. (default: :obj:`None`) + dim (int, optional): The dimension in which to normalize. 
+ (default: :obj:`0`) + :rtype: :class:`Tensor` + """ + if ptr is not None: + dim = dim + src.dim() if dim < 0 else dim + size = ([1] * dim) + [-1] + ptr = ptr.view(size) + src_max = gather_csr(segment_csr(src, ptr, reduce='max'), ptr) + out = (src - src_max).exp() + out_sum = gather_csr(segment_csr(out, ptr, reduce='sum'), ptr) + elif index is not None: + N = maybe_num_nodes(index, num_nodes) + src_max = scatter(src, index, dim, dim_size=N, reduce='max') + src_max = src_max.index_select(dim, index) + out = (src - src_max).exp() + out_sum = scatter(out, index, dim, dim_size=N, reduce='sum') + out_sum = out_sum.index_select(dim, index) + else: + raise NotImplementedError + + return out / (out_sum + 1e-16) + + +@torch.jit.script +def softmax_vec(src: Tensor, index: Optional[Tensor] = None, + ptr: Optional[Tensor] = None, num_nodes: Optional[int] = None, + dim: int = 0) -> Tensor: + r"""Computes a sparsely evaluated softmax. + Given a value tensor :attr:`src`, this function first groups the values + along the first dimension based on the indices specified in :attr:`index`, + and then proceeds to compute the softmax individually for each group. + Args: + src (Tensor): The source tensor. + index (LongTensor, optional): The indices of elements for applying the + softmax. (default: :obj:`None`) + ptr (LongTensor, optional): If given, computes the softmax based on + sorted inputs in CSR representation. (default: :obj:`None`) + num_nodes (int, optional): The number of nodes, *i.e.* + :obj:`max_val + 1` of :attr:`index`. (default: :obj:`None`) + dim (int, optional): The dimension in which to normalize. + (default: :obj:`0`) + :rtype: :class:`Tensor` + """ + if ptr is not None: + dim = dim + src.dim() if dim < 0 else dim + size = ([1] * dim) + [-1] + ptr = ptr.view(size) + src_max = gather_csr(segment_csr(src, ptr, reduce='max'), ptr) + out = (src - src_max).exp() + out_sum = gather_csr(segment_csr(out, ptr, reduce='sum'), ptr) + elif index is not None: + N = maybe_num_nodes(index, num_nodes) + src_max = scatter(src, index, dim, dim_size=N, reduce='max') + src_max = src_max.index_select(dim, index) + out = (src - src_max).exp() + out_sum = scatter(out, index, dim, dim_size=N, reduce='sum') + out_sum = out_sum.index_select(dim, index) + else: + raise NotImplementedError + + return out / (out_sum + 1e-16) \ No newline at end of file diff --git a/benchmarks/matbench_v0.1_iComFormer/results.json.gz b/benchmarks/matbench_v0.1_iComFormer/results.json.gz new file mode 100644 index 00000000..e1eda9c8 Binary files /dev/null and b/benchmarks/matbench_v0.1_iComFormer/results.json.gz differ diff --git a/benchmarks/matbench_v0.1_iComFormer/scheduler.py b/benchmarks/matbench_v0.1_iComFormer/scheduler.py new file mode 100644 index 00000000..1c97817b --- /dev/null +++ b/benchmarks/matbench_v0.1_iComFormer/scheduler.py @@ -0,0 +1,244 @@ +import types +import math +import torch +from torch import inf +from functools import wraps +import warnings +import weakref +from collections import Counter +from bisect import bisect_right +from torch.optim import Optimizer + +class LRScheduler: + + def __init__(self, optimizer, last_epoch=-1, verbose=False): + + # Attach optimizer + if not isinstance(optimizer, Optimizer): + raise TypeError('{} is not an Optimizer'.format( + type(optimizer).__name__)) + self.optimizer = optimizer + + # Initialize epoch and base learning rates + if last_epoch == -1: + for group in optimizer.param_groups: + group.setdefault('initial_lr', group['lr']) + else: + for i, group in 
enumerate(optimizer.param_groups): + if 'initial_lr' not in group: + raise KeyError("param 'initial_lr' is not specified " + "in param_groups[{}] when resuming an optimizer".format(i)) + self.base_lrs = [group['initial_lr'] for group in optimizer.param_groups] + self.last_epoch = last_epoch + + # Following https://github.com/pytorch/pytorch/issues/20124 + # We would like to ensure that `lr_scheduler.step()` is called after + # `optimizer.step()` + def with_counter(method): + if getattr(method, '_with_counter', False): + # `optimizer.step()` has already been replaced, return. + return method + + # Keep a weak reference to the optimizer instance to prevent + # cyclic references. + instance_ref = weakref.ref(method.__self__) + # Get the unbound method for the same purpose. + func = method.__func__ + cls = instance_ref().__class__ + del method + + @wraps(func) + def wrapper(*args, **kwargs): + instance = instance_ref() + instance._step_count += 1 + wrapped = func.__get__(instance, cls) + return wrapped(*args, **kwargs) + + # Note that the returned function here is no longer a bound method, + # so attributes like `__func__` and `__self__` no longer exist. + wrapper._with_counter = True + return wrapper + + self.optimizer.step = with_counter(self.optimizer.step) + self.verbose = verbose + + self._initial_step() + + def _initial_step(self): + """Initialize step counts and performs a step""" + self.optimizer._step_count = 0 + self._step_count = 0 + self.step() + + def state_dict(self): + """Returns the state of the scheduler as a :class:`dict`. + + It contains an entry for every variable in self.__dict__ which + is not the optimizer. + """ + return {key: value for key, value in self.__dict__.items() if key != 'optimizer'} + + def load_state_dict(self, state_dict): + """Loads the schedulers state. + + Args: + state_dict (dict): scheduler state. Should be an object returned + from a call to :meth:`state_dict`. + """ + self.__dict__.update(state_dict) + + def get_last_lr(self): + """ Return last computed learning rate by current scheduler. + """ + return self._last_lr + + def get_lr(self): + # Compute learning rate using chainable form of the scheduler + raise NotImplementedError + + def print_lr(self, is_verbose, group, lr, epoch=None): + """Display the current learning rate. + """ + if is_verbose: + if epoch is None: + print('Adjusting learning rate' + ' of group {} to {:.4e}.'.format(group, lr)) + else: + epoch_str = ("%.2f" if isinstance(epoch, float) else + "%.5d") % epoch + print('Epoch {}: adjusting learning rate' + ' of group {} to {:.4e}.'.format(epoch_str, group, lr)) + + + def step(self, epoch=None): + # Raise a warning if old pattern is detected + # https://github.com/pytorch/pytorch/issues/20124 + if self._step_count == 1: + if not hasattr(self.optimizer.step, "_with_counter"): + warnings.warn("Seems like `optimizer.step()` has been overridden after learning rate scheduler " + "initialization. Please, make sure to call `optimizer.step()` before " + "`lr_scheduler.step()`. See more details at " + "https://pytorch.org/docs/stable/optim.html#how-to-adjust-learning-rate", UserWarning) + + # Just check if there were two first lr_scheduler.step() calls before optimizer.step() + elif self.optimizer._step_count < 1: + warnings.warn("Detected call of `lr_scheduler.step()` before `optimizer.step()`. " + "In PyTorch 1.1.0 and later, you should call them in the opposite order: " + "`optimizer.step()` before `lr_scheduler.step()`. 
Failure to do this " + "will result in PyTorch skipping the first value of the learning rate schedule. " + "See more details at " + "https://pytorch.org/docs/stable/optim.html#how-to-adjust-learning-rate", UserWarning) + self._step_count += 1 + + with _enable_get_lr_call(self): + if epoch is None: + self.last_epoch += 1 + values = self.get_lr() + else: + warnings.warn(EPOCH_DEPRECATION_WARNING, UserWarning) + self.last_epoch = epoch + if hasattr(self, "_get_closed_form_lr"): + values = self._get_closed_form_lr() + else: + values = self.get_lr() + + for i, data in enumerate(zip(self.optimizer.param_groups, values)): + param_group, lr = data + param_group['lr'] = lr + self.print_lr(self.verbose, i, lr, epoch) + + self._last_lr = [group['lr'] for group in self.optimizer.param_groups] + +class _enable_get_lr_call: + + def __init__(self, o): + self.o = o + + def __enter__(self): + self.o._get_lr_called_within_step = True + return self + + def __exit__(self, type, value, traceback): + self.o._get_lr_called_within_step = False + +class PolynomialLR(LRScheduler): + """Decays the learning rate of each parameter group using a polynomial function + in the given total_iters. When last_epoch=-1, sets initial lr as lr. + + Args: + optimizer (Optimizer): Wrapped optimizer. + total_iters (int): The number of steps that the scheduler decays the learning rate. Default: 5. + power (int): The power of the polynomial. Default: 1.0. + verbose (bool): If ``True``, prints a message to stdout for + each update. Default: ``False``. + + Example: + >>> # xdoctest: +SKIP("undefined vars") + >>> # Assuming optimizer uses lr = 0.001 for all groups + >>> # lr = 0.001 if epoch == 0 + >>> # lr = 0.00075 if epoch == 1 + >>> # lr = 0.00050 if epoch == 2 + >>> # lr = 0.00025 if epoch == 3 + >>> # lr = 0.0 if epoch >= 4 + >>> scheduler = PolynomialLR(self.opt, total_iters=4, power=1.0) + >>> for epoch in range(100): + >>> train(...) + >>> validate(...) 
+ >>> scheduler.step() + """ + def __init__(self, optimizer, total_iters=5, power=1.0, last_epoch=-1, verbose=False, end_lr=1e-5, decay_steps=10): + self.total_iters = total_iters + self.power = power + self.end_lr = end_lr + self.decay_steps = decay_steps + super().__init__(optimizer, last_epoch, verbose) + + def get_lr(self): + # print(self.last_epoch) + # print(self._step_count) + + # if not self._get_lr_called_within_step: + # warnings.warn("To get the last learning rate computed by the scheduler, " + # "please use `get_last_lr()`.", UserWarning) + + step = min(self.last_epoch, self.decay_steps) + return [ + ( + (base_lr - self.end_lr) * (1.0 - step / self.decay_steps) ** self.power + self.end_lr + ) + for base_lr in self.base_lrs + ] + + # if not self._get_lr_called_within_step: + # warnings.warn("To get the last learning rate computed by the scheduler, " + # "please use `get_last_lr()`.", UserWarning) + + # if self.last_epoch == 0 or self.last_epoch > self.total_iters: + # return [group["lr"] for group in self.optimizer.param_groups] + + # decay_factor = ((1.0 - self.last_epoch / self.total_iters) / (1.0 - (self.last_epoch - 1) / self.total_iters)) ** self.power + # return [group["lr"] * decay_factor for group in self.optimizer.param_groups] + + def _get_closed_form_lr(self): + + print(self.last_epoch) + step = min(self.last_epoch, self.decay_steps) + return [ + ( + (base_lr - self.end_lr) * (1.0 - self.last_epoch / self.decay_steps) ** self.power + self.end_lr + ) + for base_lr in self.base_lrs + ] + + # return [ + # ( + # base_lr * (1.0 - min(self.total_iters, self.last_epoch) / self.total_iters) ** self.power + # ) + # for base_lr in self.base_lrs + # ] + + # def decayed_learning_rate(step): + # step = min(step, decay_steps) + # return ((initial_learning_rate - end_learning_rate) * + # (1 - step / decay_steps) ^ (power) + # ) + end_learning_rate \ No newline at end of file diff --git a/benchmarks/matbench_v0.1_iComFormer/train.py b/benchmarks/matbench_v0.1_iComFormer/train.py new file mode 100644 index 00000000..ff506fa4 --- /dev/null +++ b/benchmarks/matbench_v0.1_iComFormer/train.py @@ -0,0 +1,913 @@ +from functools import partial + +# from pathlib import Path +from typing import Any, Dict, Union + +import ignite +import torch + +from ignite.contrib.handlers import TensorboardLogger +try: + from ignite.contrib.handlers.stores import EpochOutputStore +except Exception as exp: + from ignite.handlers.stores import EpochOutputStore + + pass +from ignite.handlers import EarlyStopping +from ignite.contrib.handlers.tensorboard_logger import ( + global_step_from_engine, +) +from ignite.contrib.handlers.tqdm_logger import ProgressBar +from ignite.engine import ( + Events, + create_supervised_evaluator, + create_supervised_trainer, +) +from ignite.contrib.metrics import ROC_AUC, RocCurve +from ignite.metrics import ( + Accuracy, + Precision, + Recall, + ConfusionMatrix, +) +import pickle as pk +import numpy as np +from ignite.handlers import Checkpoint, DiskSaver, TerminateOnNan +from ignite.metrics import Loss, MeanAbsoluteError +from torch import nn +from matformer import models +from matformer.data import get_train_val_loaders +from matformer.config import TrainingConfig +# from matformer.models.pyg_att import Matformer + +from jarvis.db.jsonutils import dumpjson +import json +import pprint + +import os + +# import sys +# sys.path.append("/mnt/data/shared/congfu/CompCrystal/NewModel_27/matformer/") +# from scheduler import PolynomialLR + + +import types +import math +import torch +from 
torch import inf +from functools import wraps +import warnings +import weakref +from collections import Counter +from bisect import bisect_right +from torch.optim import Optimizer + +class LRScheduler: + + def __init__(self, optimizer, last_epoch=-1, verbose=False): + + # Attach optimizer + if not isinstance(optimizer, Optimizer): + raise TypeError('{} is not an Optimizer'.format( + type(optimizer).__name__)) + self.optimizer = optimizer + + # Initialize epoch and base learning rates + if last_epoch == -1: + for group in optimizer.param_groups: + group.setdefault('initial_lr', group['lr']) + else: + for i, group in enumerate(optimizer.param_groups): + if 'initial_lr' not in group: + raise KeyError("param 'initial_lr' is not specified " + "in param_groups[{}] when resuming an optimizer".format(i)) + self.base_lrs = [group['initial_lr'] for group in optimizer.param_groups] + self.last_epoch = last_epoch + + # Following https://github.com/pytorch/pytorch/issues/20124 + # We would like to ensure that `lr_scheduler.step()` is called after + # `optimizer.step()` + def with_counter(method): + if getattr(method, '_with_counter', False): + # `optimizer.step()` has already been replaced, return. + return method + + # Keep a weak reference to the optimizer instance to prevent + # cyclic references. + instance_ref = weakref.ref(method.__self__) + # Get the unbound method for the same purpose. + func = method.__func__ + cls = instance_ref().__class__ + del method + + @wraps(func) + def wrapper(*args, **kwargs): + instance = instance_ref() + instance._step_count += 1 + wrapped = func.__get__(instance, cls) + return wrapped(*args, **kwargs) + + # Note that the returned function here is no longer a bound method, + # so attributes like `__func__` and `__self__` no longer exist. + wrapper._with_counter = True + return wrapper + + self.optimizer.step = with_counter(self.optimizer.step) + self.verbose = verbose + + self._initial_step() + + def _initial_step(self): + """Initialize step counts and performs a step""" + self.optimizer._step_count = 0 + self._step_count = 0 + self.step() + + def state_dict(self): + """Returns the state of the scheduler as a :class:`dict`. + + It contains an entry for every variable in self.__dict__ which + is not the optimizer. + """ + return {key: value for key, value in self.__dict__.items() if key != 'optimizer'} + + def load_state_dict(self, state_dict): + """Loads the schedulers state. + + Args: + state_dict (dict): scheduler state. Should be an object returned + from a call to :meth:`state_dict`. + """ + self.__dict__.update(state_dict) + + def get_last_lr(self): + """ Return last computed learning rate by current scheduler. + """ + return self._last_lr + + def get_lr(self): + # Compute learning rate using chainable form of the scheduler + raise NotImplementedError + + def print_lr(self, is_verbose, group, lr, epoch=None): + """Display the current learning rate. 
+ """ + if is_verbose: + if epoch is None: + print('Adjusting learning rate' + ' of group {} to {:.4e}.'.format(group, lr)) + else: + epoch_str = ("%.2f" if isinstance(epoch, float) else + "%.5d") % epoch + print('Epoch {}: adjusting learning rate' + ' of group {} to {:.4e}.'.format(epoch_str, group, lr)) + + + def step(self, epoch=None): + # Raise a warning if old pattern is detected + # https://github.com/pytorch/pytorch/issues/20124 + if self._step_count == 1: + if not hasattr(self.optimizer.step, "_with_counter"): + warnings.warn("Seems like `optimizer.step()` has been overridden after learning rate scheduler " + "initialization. Please, make sure to call `optimizer.step()` before " + "`lr_scheduler.step()`. See more details at " + "https://pytorch.org/docs/stable/optim.html#how-to-adjust-learning-rate", UserWarning) + + # Just check if there were two first lr_scheduler.step() calls before optimizer.step() + elif self.optimizer._step_count < 1: + warnings.warn("Detected call of `lr_scheduler.step()` before `optimizer.step()`. " + "In PyTorch 1.1.0 and later, you should call them in the opposite order: " + "`optimizer.step()` before `lr_scheduler.step()`. Failure to do this " + "will result in PyTorch skipping the first value of the learning rate schedule. " + "See more details at " + "https://pytorch.org/docs/stable/optim.html#how-to-adjust-learning-rate", UserWarning) + self._step_count += 1 + + with _enable_get_lr_call(self): + if epoch is None: + self.last_epoch += 1 + values = self.get_lr() + else: + warnings.warn(EPOCH_DEPRECATION_WARNING, UserWarning) + self.last_epoch = epoch + if hasattr(self, "_get_closed_form_lr"): + values = self._get_closed_form_lr() + else: + values = self.get_lr() + + for i, data in enumerate(zip(self.optimizer.param_groups, values)): + param_group, lr = data + param_group['lr'] = lr + self.print_lr(self.verbose, i, lr, epoch) + + self._last_lr = [group['lr'] for group in self.optimizer.param_groups] + +class _enable_get_lr_call: + + def __init__(self, o): + self.o = o + + def __enter__(self): + self.o._get_lr_called_within_step = True + return self + + def __exit__(self, type, value, traceback): + self.o._get_lr_called_within_step = False + +class PolynomialLR(LRScheduler): + """Decays the learning rate of each parameter group using a polynomial function + in the given total_iters. When last_epoch=-1, sets initial lr as lr. + + Args: + optimizer (Optimizer): Wrapped optimizer. + total_iters (int): The number of steps that the scheduler decays the learning rate. Default: 5. + power (int): The power of the polynomial. Default: 1.0. + verbose (bool): If ``True``, prints a message to stdout for + each update. Default: ``False``. + + Example: + >>> # xdoctest: +SKIP("undefined vars") + >>> # Assuming optimizer uses lr = 0.001 for all groups + >>> # lr = 0.001 if epoch == 0 + >>> # lr = 0.00075 if epoch == 1 + >>> # lr = 0.00050 if epoch == 2 + >>> # lr = 0.00025 if epoch == 3 + >>> # lr = 0.0 if epoch >= 4 + >>> scheduler = PolynomialLR(self.opt, total_iters=4, power=1.0) + >>> for epoch in range(100): + >>> train(...) + >>> validate(...) 
+ >>> scheduler.step() + """ + def __init__(self, optimizer, total_iters=5, power=1.0, last_epoch=-1, verbose=False, end_lr=1e-5, decay_steps=10): + self.total_iters = total_iters + self.power = power + self.end_lr = end_lr + self.decay_steps = decay_steps + super().__init__(optimizer, last_epoch, verbose) + + def get_lr(self): + # print(self.last_epoch) + # print(self._step_count) + + # if not self._get_lr_called_within_step: + # warnings.warn("To get the last learning rate computed by the scheduler, " + # "please use `get_last_lr()`.", UserWarning) + + step = min(self.last_epoch, self.decay_steps) + return [ + ( + (base_lr - self.end_lr) * (1.0 - step / self.decay_steps) ** self.power + self.end_lr + ) + for base_lr in self.base_lrs + ] + + # if not self._get_lr_called_within_step: + # warnings.warn("To get the last learning rate computed by the scheduler, " + # "please use `get_last_lr()`.", UserWarning) + + # if self.last_epoch == 0 or self.last_epoch > self.total_iters: + # return [group["lr"] for group in self.optimizer.param_groups] + + # decay_factor = ((1.0 - self.last_epoch / self.total_iters) / (1.0 - (self.last_epoch - 1) / self.total_iters)) ** self.power + # return [group["lr"] * decay_factor for group in self.optimizer.param_groups] + + def _get_closed_form_lr(self): + + print(self.last_epoch) + step = min(self.last_epoch, self.decay_steps) + return [ + ( + (base_lr - self.end_lr) * (1.0 - self.last_epoch / self.decay_steps) ** self.power + self.end_lr + ) + for base_lr in self.base_lrs + ] + + # return [ + # ( + # base_lr * (1.0 - min(self.total_iters, self.last_epoch) / self.total_iters) ** self.power + # ) + # for base_lr in self.base_lrs + # ] + + # def decayed_learning_rate(step): + # step = min(step, decay_steps) + # return ((initial_learning_rate - end_learning_rate) * + # (1 - step / decay_steps) ^ (power) + # ) + end_learning_rate + +########################################################################################### + + +# torch config +torch.set_default_dtype(torch.float32) + +device = "cpu" +if torch.cuda.is_available(): + device = torch.device("cuda") + + +def activated_output_transform(output): + """Exponentiate output.""" + y_pred, y = output + y_pred = torch.exp(y_pred) + y_pred = y_pred[:, 1] + return y_pred, y + + +def make_standard_scalar_and_pca(output): + """Use standard scalar and PCS for multi-output data.""" + sc = pk.load(open(os.path.join(tmp_output_dir, "sc.pkl"), "rb")) + y_pred, y = output + y_pred = torch.tensor(sc.transform(y_pred.cpu().numpy()), device=device) + y = torch.tensor(sc.transform(y.cpu().numpy()), device=device) + return y_pred, y + + +def thresholded_output_transform(output): + """Round off output.""" + y_pred, y = output + y_pred = torch.round(torch.exp(y_pred)) + # print ('output',y_pred) + return y_pred, y + + +def group_decay(model): + """Omit weight decay from bias and batchnorm params.""" + decay, no_decay = [], [] + + for name, p in model.named_parameters(): + if "bias" in name or "bn" in name or "norm" in name: + no_decay.append(p) + else: + decay.append(p) + + return [ + {"params": decay}, + {"params": no_decay, "weight_decay": 0}, + ] + + +def setup_optimizer(params, config: TrainingConfig): + """Set up optimizer for param groups.""" + if config.optimizer == "adamw": + optimizer = torch.optim.AdamW( + params, + lr=config.learning_rate, + weight_decay=config.weight_decay, + ) + elif config.optimizer == "sgd": + optimizer = torch.optim.SGD( + params, + lr=config.learning_rate, + momentum=0.9, + 
weight_decay=config.weight_decay, + ) + return optimizer + + +def train_dgl( + config: Union[TrainingConfig, Dict[str, Any]], + model: nn.Module = None, + train_val_test_loaders=[], + test_only=False, + use_save=True, + mp_id_list=None, + train_inputs=None, + train_outputs=None, + test_inputs=None, + test_outputs=None, + model_variant=None, +): + """ + `config` should conform to matformer.conf.TrainingConfig, and + if passed as a dict with matching keys, pydantic validation is used + """ + print(config) + if type(config) is dict: + try: + config = TrainingConfig(**config) + except Exception as exp: + print("Check", exp) + print('error in converting to training config!') + import os + + if not os.path.exists(config.output_dir): + os.makedirs(config.output_dir) + checkpoint_dir = os.path.join(config.output_dir) + deterministic = False + classification = False + print("config:") + tmp = config.dict() + f = open(os.path.join(config.output_dir, "config.json"), "w") + f.write(json.dumps(tmp, indent=4)) + f.close() + global tmp_output_dir + tmp_output_dir = config.output_dir + pprint.pprint(tmp) + if config.classification_threshold is not None: + classification = True + if config.random_seed is not None: + deterministic = True + ignite.utils.manual_seed(config.random_seed) + + # import pdb; pdb.set_trace() + line_graph = True + if not train_val_test_loaders: + # use input standardization for all real-valued feature sets + ( + train_loader, + val_loader, + test_loader, + prepare_batch, + mean_train, + std_train, + ) = get_train_val_loaders( + dataset=config.dataset, + target=config.target, + n_train=config.n_train, + n_val=config.n_val, + n_test=config.n_test, + train_ratio=config.train_ratio, + val_ratio=config.val_ratio, + test_ratio=config.test_ratio, + batch_size=config.batch_size, + atom_features=config.atom_features, + neighbor_strategy=config.neighbor_strategy, + standardize=config.atom_features != "cgcnn", + line_graph=line_graph, + id_tag=config.id_tag, + pin_memory=config.pin_memory, + workers=config.num_workers, + save_dataloader=config.save_dataloader, + use_canonize=config.use_canonize, + filename=config.filename, + cutoff=config.cutoff, + max_neighbors=config.max_neighbors, + output_features=config.model.output_features, + classification_threshold=config.classification_threshold, + target_multiplication_factor=config.target_multiplication_factor, + standard_scalar_and_pca=config.standard_scalar_and_pca, + keep_data_order=config.keep_data_order, + output_dir=config.output_dir, + matrix_input=config.matrix_input, + pyg_input=config.pyg_input, + use_lattice=config.use_lattice, + use_angle=config.use_angle, + use_save=use_save, + mp_id_list=mp_id_list, + train_inputs=train_inputs, + train_outputs=train_outputs, + test_inputs=test_inputs, + test_outputs=test_outputs, + ) + else: + train_loader = train_val_test_loaders[0] + val_loader = train_val_test_loaders[1] + test_loader = train_val_test_loaders[2] + prepare_batch = train_val_test_loaders[3] + prepare_batch = partial(prepare_batch, device=device) + if classification: + config.model.classification = True + # define network, optimizer, scheduler + if model_variant == 'matformerinvariant': + from matformer.models.pyg_att import MatformerInvariant as Matformer + elif model_variant == 'matformerequivariant': + from matformer.models.pyg_att import MatformerEquivariant as Matformer + _model = { + "matformer" : Matformer, + } + if std_train is None: + std_train = 1.0 + print('std train is none!') + print('std train:', std_train) + if model is 
None: + net = _model.get(config.model.name)(config.model) + print("config:") + pprint.pprint(config.model.dict()) + else: + net = model + + net.to(device) + if config.distributed: + import torch.distributed as dist + import os + + def setup(rank, world_size): + os.environ["MASTER_ADDR"] = "localhost" + os.environ["MASTER_PORT"] = "12355" + + # initialize the process group + dist.init_process_group("gloo", rank=rank, world_size=world_size) + + def cleanup(): + dist.destroy_process_group() + + setup(2, 2) + net = torch.nn.parallel.DistributedDataParallel( + net + ) + params = group_decay(net) + optimizer = setup_optimizer(params, config) + + if config.scheduler == "none": + scheduler = torch.optim.lr_scheduler.LambdaLR( + optimizer, lambda epoch: 1.0 + ) + + elif config.scheduler == "onecycle": + steps_per_epoch = len(train_loader) + pct_start = config.warmup_steps / (config.epochs * steps_per_epoch) + scheduler = torch.optim.lr_scheduler.OneCycleLR( + optimizer, + max_lr=config.learning_rate, + epochs=config.epochs, + steps_per_epoch=steps_per_epoch, + # pct_start=pct_start, + pct_start=0.3, + ) + elif config.scheduler == "step": + scheduler = torch.optim.lr_scheduler.StepLR( + optimizer, + step_size=100000, + gamma=0.96, + ) + elif config.scheduler == "polynomial": + steps_per_epoch = len(train_loader) + num_steps = config.epochs * steps_per_epoch + scheduler = PolynomialLR( + optimizer, + decay_steps=num_steps, + end_lr=1e-5, + ) + + # select configured loss function + criteria = { + "mse": nn.MSELoss(), + "l1": nn.L1Loss(), + } + criterion = criteria[config.criterion] + # set up training engine and evaluators + metrics = {"loss": Loss(criterion), "mae": MeanAbsoluteError() * std_train, "neg_mae": -1.0 * MeanAbsoluteError() * std_train} + trainer = create_supervised_trainer( + net, + optimizer, + criterion, + prepare_batch=prepare_batch, + device=device, + deterministic=deterministic, + ) + evaluator = create_supervised_evaluator( + net, + metrics=metrics, + prepare_batch=prepare_batch, + device=device, + ) + train_evaluator = create_supervised_evaluator( + net, + metrics=metrics, + prepare_batch=prepare_batch, + device=device, + ) + if test_only: + checkpoint_tmp = torch.load('/your_model_path.pt') + to_load = { + "model": net, + "optimizer": optimizer, + "lr_scheduler": scheduler, + "trainer": trainer, + } + Checkpoint.load_objects(to_load=to_load, checkpoint=checkpoint_tmp) + net.eval() + targets = [] + predictions = [] + import time + t1 = time.time() + with torch.no_grad(): + for dat in test_loader: + g, lg, _, target = dat + try: + out_data = net([g.to(device), lg.to(device), _.to(device)]) + success_flag=1 + except: # just in case + print('error for this data') + print(g) + success_flag=0 + if success_flag > 0: + out_data = out_data.cpu().numpy().tolist() + target = target.cpu().numpy().flatten().tolist() + if len(target) == 1: + target = target[0] + targets.append(target) + predictions.append(out_data) + t2 = time.time() + f.close() + from sklearn.metrics import mean_absolute_error + targets = np.array(targets) * std_train + predictions = np.array(predictions) * std_train + print("Test MAE:", mean_absolute_error(targets, predictions)) + print("Total test time:", t2-t1) + return mean_absolute_error(targets, predictions) + # ignite event handlers: + trainer.add_event_handler(Events.EPOCH_COMPLETED, TerminateOnNan()) + + # apply learning rate scheduler + trainer.add_event_handler( + Events.ITERATION_COMPLETED, lambda engine: scheduler.step() + ) + + # checkpoint_tmp = 
torch.load("/mnt/data/shared/congfu/CompCrystal/NewModel_27/matformer/scripts/matbench_mp_e_form_equivariant_max25_epoch500_lr1e-3_L1_fold1/checkpoint_299.pt") + # to_load = { + # "model": net, + # "optimizer": optimizer, + # "lr_scheduler": scheduler, + # "trainer": trainer, + # } + # Checkpoint.load_objects(to_load=to_load, checkpoint=checkpoint_tmp) + # print('checkpoint.pt loaded') + # print('current epoch:', trainer.state.epoch) + # print('current optimizer:', optimizer) + # print('current scheduler:', scheduler) + + if config.write_checkpoint: + # model checkpointing + to_save = { + "model": net, + "optimizer": optimizer, + "lr_scheduler": scheduler, + "trainer": trainer, + } + handler = Checkpoint( + to_save, + DiskSaver(checkpoint_dir, create_dir=True, require_empty=False), + n_saved=2, + global_step_transform=lambda *_: trainer.state.epoch, + ) + trainer.add_event_handler(Events.EPOCH_COMPLETED, handler) + # evaluate save + to_save = {"model": net} + handler = Checkpoint( + to_save, + DiskSaver(checkpoint_dir, create_dir=True, require_empty=False), + n_saved=5, + filename_prefix='best', + score_name="neg_mae", + global_step_transform=lambda *_: trainer.state.epoch, + ) + evaluator.add_event_handler(Events.EPOCH_COMPLETED, handler) + if config.progress: + pbar = ProgressBar() + pbar.attach(trainer, output_transform=lambda x: {"loss": x}) + # pbar.attach(evaluator,output_transform=lambda x: {"mae": x}) + + history = { + "train": {m: [] for m in metrics.keys()}, + "validation": {m: [] for m in metrics.keys()}, + } + + if config.store_outputs: + # in history["EOS"] + eos = EpochOutputStore() + eos.attach(evaluator) + train_eos = EpochOutputStore() + train_eos.attach(train_evaluator) + + # collect evaluation performance + @trainer.on(Events.EPOCH_COMPLETED) + def log_results(engine): + """Print training and validation metrics to console.""" + # train_evaluator.run(train_loader) + # evaluator.run(val_loader) + + # tmetrics = train_evaluator.state.metrics + # vmetrics = evaluator.state.metrics + # for metric in metrics.keys(): + # tm = tmetrics[metric] + # vm = vmetrics[metric] + # if metric == "roccurve": + # tm = [k.tolist() for k in tm] + # vm = [k.tolist() for k in vm] + # if isinstance(tm, torch.Tensor): + # tm = tm.cpu().numpy().tolist() + # vm = vm.cpu().numpy().tolist() + + # history["train"][metric].append(tm) + # history["validation"][metric].append(vm) + + # train_evaluator.run(train_loader) + evaluator.run(val_loader) + + vmetrics = evaluator.state.metrics + for metric in metrics.keys(): + vm = vmetrics[metric] + t_metric = metric + if metric == "roccurve": + vm = [k.tolist() for k in vm] + if isinstance(vm, torch.Tensor): + vm = vm.cpu().numpy().tolist() + + history["validation"][metric].append(vm) + + + + epoch_num = len(history["validation"][t_metric]) + if epoch_num % 20 == 0: + train_evaluator.run(train_loader) + tmetrics = train_evaluator.state.metrics + for metric in metrics.keys(): + tm = tmetrics[metric] + if metric == "roccurve": + tm = [k.tolist() for k in tm] + if isinstance(tm, torch.Tensor): + tm = tm.cpu().numpy().tolist() + + history["train"][metric].append(tm) + else: + tmetrics = {} + tmetrics['mae'] = -1 + + + # for metric in metrics.keys(): + # history["train"][metric].append(tmetrics[metric]) + # history["validation"][metric].append(vmetrics[metric]) + + if config.store_outputs: + history["EOS"] = eos.data + history["trainEOS"] = train_eos.data + dumpjson( + filename=os.path.join(config.output_dir, "history_val.json"), + data=history["validation"], + ) + 
+            dumpjson(
+                filename=os.path.join(config.output_dir, "history_train.json"),
+                data=history["train"],
+            )
+        if config.progress:
+            pbar = ProgressBar()
+            if not classification:
+                pbar.log_message(f"Val_MAE: {vmetrics['mae']:.4f}")
+                pbar.log_message(f"Train_MAE: {tmetrics['mae']:.4f}")
+            else:
+                pbar.log_message(f"Train ROC AUC: {tmetrics['rocauc']:.4f}")
+                pbar.log_message(f"Val ROC AUC: {vmetrics['rocauc']:.4f}")
+
+    if config.n_early_stopping is not None:
+        if classification:
+            my_metrics = "accuracy"
+        else:
+            my_metrics = "neg_mae"
+
+        def default_score_fn(engine):
+            score = engine.state.metrics[my_metrics]
+            return score
+
+        es_handler = EarlyStopping(
+            patience=config.n_early_stopping,
+            score_function=default_score_fn,
+            trainer=trainer,
+        )
+        evaluator.add_event_handler(Events.EPOCH_COMPLETED, es_handler)
+    # optionally log results to tensorboard
+    if config.log_tensorboard:
+
+        tb_logger = TensorboardLogger(
+            log_dir=os.path.join(config.output_dir, "tb_logs", "test")
+        )
+        for tag, evaluator in [
+            ("training", train_evaluator),
+            ("validation", evaluator),
+        ]:
+            tb_logger.attach_output_handler(
+                evaluator,
+                event_name=Events.EPOCH_COMPLETED,
+                tag=tag,
+                metric_names=["loss", "mae"],
+                global_step_transform=global_step_from_engine(trainer),
+            )
+
+    trainer.run(train_loader, max_epochs=config.epochs)
+
+    if config.log_tensorboard:
+        test_loss = evaluator.state.metrics["loss"]
+        tb_logger.writer.add_hparams(config, {"hparam/test_loss": test_loss})
+        tb_logger.close()
+    if config.write_predictions and classification:
+        net.eval()
+        f = open(
+            os.path.join(config.output_dir, "prediction_results_test_set.csv"),
+            "w",
+        )
+        f.write("id,target,prediction\n")
+        targets = []
+        predictions = []
+        with torch.no_grad():
+            ids = test_loader.dataset.ids  # [test_loader.dataset.indices]
+            for dat, id in zip(test_loader, ids):
+                g, lg, target = dat
+                out_data = net([g.to(device), lg.to(device)])
+                # out_data = torch.exp(out_data.cpu())
+                top_p, top_class = torch.topk(torch.exp(out_data), k=1)
+                target = int(target.cpu().numpy().flatten().tolist()[0])
+
+                f.write("%s, %d, %d\n" % (id, (target), (top_class)))
+                targets.append(target)
+                predictions.append(
+                    top_class.cpu().numpy().flatten().tolist()[0]
+                )
+        f.close()
+        from sklearn.metrics import roc_auc_score
+
+        print("predictions", predictions)
+        print("targets", targets)
+        print(
+            "Test ROCAUC:",
+            roc_auc_score(np.array(targets), np.array(predictions)),
+        )
+
+    if (
+        config.write_predictions
+        and not classification
+        and config.model.output_features > 1
+    ):
+        net.eval()
+        mem = []
+        with torch.no_grad():
+            ids = test_loader.dataset.ids  # [test_loader.dataset.indices]
+            for dat, id in zip(test_loader, ids):
+                g, lg, target = dat
+                out_data = net([g.to(device), lg.to(device)])
+                out_data = out_data.cpu().numpy().tolist()
+                if config.standard_scalar_and_pca:
+                    sc = pk.load(open("sc.pkl", "rb"))
+                    out_data = list(
+                        sc.transform(np.array(out_data).reshape(1, -1))[0]
+                    )  # [0][0]
+                target = target.cpu().numpy().flatten().tolist()
+                info = {}
+                info["id"] = id
+                info["target"] = target
+                info["predictions"] = out_data
+                mem.append(info)
+        dumpjson(
+            filename=os.path.join(
+                config.output_dir, "multi_out_predictions.json"
+            ),
+            data=mem,
+        )
+    if (
+        config.write_predictions
+        and not classification
+        and config.model.output_features == 1
+    ):
+        net.eval()
+        f = open(
+            os.path.join(config.output_dir, "prediction_results_test_set.csv"),
+            "w",
+        )
+        f.write("target,prediction\n")
+        targets = []
+        predictions = []
+        with torch.no_grad():
+            for dat in test_loader:
+                g, lg, _, target = dat
+                out_data = net([g.to(device), lg.to(device), lg.to(device)])
+                out_data = out_data.cpu().numpy().tolist()
+                target = target.cpu().numpy().flatten().tolist()
+                if len(target) == 1:
+                    target = target[0]
+                targets.append(target)
+                predictions.append(out_data)
+        # write de-normalized target/prediction pairs collected above
+        for t, p in zip(
+            np.array(targets).flatten(), np.array(predictions).flatten()
+        ):
+            f.write(
+                "%6f, %6f\n"
+                % (t * std_train + mean_train, p * std_train + mean_train)
+            )
+        f.close()
+        from sklearn.metrics import mean_absolute_error
+
+        print(
+            "Test MAE:",
+            mean_absolute_error(np.array(targets), np.array(predictions)) * std_train,
+            "STD train:",
+            std_train,
+        )
+    if config.store_outputs and not classification:
+        x = []
+        y = []
+        for i in history["EOS"]:
+            x.append(i[0].cpu().numpy().tolist())
+            y.append(i[1].cpu().numpy().tolist())
+        x = np.array(x, dtype="float").flatten()
+        y = np.array(y, dtype="float").flatten()
+        f = open(
+            os.path.join(
+                config.output_dir, "prediction_results_train_set.csv"
+            ),
+            "w",
+        )
+        # TODO: Add IDs
+        f.write("target,prediction\n")
+        for i, j in zip(x, y):
+            f.write("%6f, %6f\n" % (j, i))
+        f.close()
+    # return history
+    return (
+        np.array(targets) * std_train + mean_train,
+        np.array(predictions) * std_train + mean_train,
+    )
+
+
diff --git a/benchmarks/matbench_v0.1_iComFormer/train.sh b/benchmarks/matbench_v0.1_iComFormer/train.sh
new file mode 100644
index 00000000..a883c5a2
--- /dev/null
+++ b/benchmarks/matbench_v0.1_iComFormer/train.sh
@@ -0,0 +1,17 @@
+#!/bin/sh
+
+GPU=4
+
+fold=0
+
+CUDA_VISIBLE_DEVICES=${GPU} \
+python ./train.py \
+--output_dir="../matbench_mp_e_form_epoch500_max_25_lr1e-3_4layer_L1_fold"$fold \
+--max_neighbors=25 \
+--epochs=500 \
+--batch_size=64 \
+--task_name="matbench_mp_e_form" \
+--lr=1e-3 \
+--criterion='l1' \
+--fold_num=$fold \
+--multi_GPU
\ No newline at end of file
diff --git a/benchmarks/matbench_v0.1_iComFormer/train_props.py b/benchmarks/matbench_v0.1_iComFormer/train_props.py
new file mode 100644
index 00000000..378fc5c0
--- /dev/null
+++ b/benchmarks/matbench_v0.1_iComFormer/train_props.py
@@ -0,0 +1,304 @@
+"""Helper function for high-throughput GNN trainings."""
+"""Implementation based on the template of ALIGNN."""
+import matplotlib.pyplot as plt
+
+# import numpy as np
+import time
+from matformer.train import train_dgl
+import os
+import numpy as np
+
+# from sklearn.metrics import mean_absolute_error
+plt.switch_backend("agg")
+
+
+def train_prop_model(
+    prop="",
+    dataset="dft_3d",
+    write_predictions=True,
+    name="pygatt",
+    save_dataloader=False,
+    train_ratio=None,
+    classification_threshold=None,
+    val_ratio=None,
+    test_ratio=None,
+    learning_rate=0.001,
+    batch_size=None,
+    scheduler=None,
+    n_epochs=None,
+    id_tag=None,
+    num_workers=None,
+    weight_decay=None,
+    edge_input_features=None,
+    triplet_input_features=None,
+    embedding_features=None,
+    hidden_features=None,
+    output_features=None,
+    random_seed=None,
+    n_early_stopping=None,
+    cutoff=None,
+    max_neighbors=None,
+    matrix_input=False,
+    pyg_input=False,
+    use_lattice=False,
+    use_angle=False,
+    output_dir=None,
+    neighbor_strategy="k-nearest",
+    test_only=False,
+    use_save=True,
+    mp_id_list=None,
+    file_name=None,
+    atom_features="cgcnn",
+    task_name=None,
+    save_dir=None,
+    criterion=None,
+    multi_GPU=False,
+    fold_num=None,
+    model_variant=None,
+):
+    """Train models for a dataset and a property."""
+    if scheduler is None:
+        scheduler = "onecycle"
+        # scheduler = "none"
+    if batch_size is None:
+        batch_size = 64
+    if n_epochs is None:
+        n_epochs = 500
+    if num_workers is None:
+        num_workers = 10
+    config = {
+        "dataset": dataset,
+ "target": "label", #prop, + "epochs": n_epochs, # 00,#00, + "batch_size": batch_size, # 0, + "weight_decay": 1e-05, + "learning_rate": learning_rate, + "criterion": criterion, #'l1', #"mse", + "optimizer": "adamw", + "scheduler": scheduler, + "save_dataloader": save_dataloader, + "pin_memory": False, + "write_predictions": write_predictions, + "num_workers": num_workers, + "classification_threshold": classification_threshold, + "atom_features": atom_features, + "model": { + "name": name, + }, + } + if n_early_stopping is not None: + config["n_early_stopping"] = n_early_stopping + if cutoff is not None: + config["cutoff"] = cutoff + if max_neighbors is not None: + config["max_neighbors"] = max_neighbors + if weight_decay is not None: + config["weight_decay"] = weight_decay + if edge_input_features is not None: + config["model"]["edge_input_features"] = edge_input_features + if hidden_features is not None: + config["model"]["hidden_features"] = hidden_features + if embedding_features is not None: + config["model"]["embedding_features"] = embedding_features + if output_features is not None: + config["model"]["output_features"] = output_features + if random_seed is not None: + config["random_seed"] = random_seed + if file_name is not None: + config["filename"] = file_name + # if model_name is not None: + # config['model']['name']=model_name + config["matrix_input"] = matrix_input + config["pyg_input"] = pyg_input + config["use_lattice"] = use_lattice + config["use_angle"] = use_angle + config["model"]["use_angle"] = use_angle + config["neighbor_strategy"] = neighbor_strategy + # config["output_dir"] = '.' + if output_dir is not None: + config["output_dir"] = output_dir + + if id_tag is not None: + config["id_tag"] = id_tag + if train_ratio is not None: + config["train_ratio"] = train_ratio + if val_ratio is None: + raise ValueError("Enter val_ratio.") + + if test_ratio is None: + raise ValueError("Enter test_ratio.") + config["val_ratio"] = val_ratio + config["test_ratio"] = test_ratio + if dataset == "jv_3d": + # config["save_dataloader"]=True + config["num_workers"] = 4 + config["pin_memory"] = False + # config["learning_rate"] = 0.001 + # config["epochs"] = 300 + + if dataset == "mp_3d_2020": + config["id_tag"] = "id" + config["num_workers"] = 0 + if dataset == "megnet2": + config["id_tag"] = "id" + config["num_workers"] = 0 + if dataset == "megnet": + config["id_tag"] = "id" + if prop == "e_form" or prop == "gap pbe": + config["n_train"] = 60000 + config["n_val"] = 5000 + config["n_test"] = 4239 + # config["learning_rate"] = 0.01 + # config["epochs"] = 300 + config["num_workers"] = 8 + else: + config["n_train"] = 4664 + config["n_val"] = 393 + config["n_test"] = 393 + if dataset == "oqmd_3d_no_cfid": + config["id_tag"] = "_oqmd_entry_id" + config["num_workers"] = 0 + if dataset == "hmof" and prop == "co2_absp": + config["model"]["output_features"] = 5 + if dataset == "edos_pdos": + if prop == "edos_up": + config["model"]["output_features"] = 300 + elif prop == "pdos_elast": + config["model"]["output_features"] = 200 + else: + raise ValueError("Target not available.") + if dataset == "qm9_std_jctc": + config["id_tag"] = "id" + config["n_train"] = 110000 + config["n_val"] = 10000 + config["n_test"] = 10829 + + # config["batch_size"] = 64 + config["cutoff"] = 5.0 + config["standard_scalar_and_pca"] = False + + if dataset == "qm9_dgl": + config["id_tag"] = "id" + config["n_train"] = 110000 + config["n_val"] = 10000 + config["n_test"] = 10831 + config["standard_scalar_and_pca"] = False + 
config["batch_size"] = 64 + config["cutoff"] = 5.0 + if config["target"] == "all": + config["model"]["output_features"] = 12 + + # config["max_neighbors"] = 9 + + if dataset == "hpov": + config["id_tag"] = "id" + if dataset == "qm9": + config["id_tag"] = "id" + config["n_train"] = 110000 + config["n_val"] = 10000 + config["n_test"] = 13885 + config["batch_size"] = batch_size + config["cutoff"] = 5.0 + config["max_neighbors"] = 9 + # config['atom_features']='atomic_number' + if prop in ["homo", "lumo", "gap", "zpve", "U0", "U", "H", "G"]: + config["target_multiplication_factor"] = 27.211386024367243 + + if dataset == 'mpf': + config["id_tag"] = "id" + config["n_train"] = 169516 + config["n_val"] = 9417 + config["n_test"] = 9417 + + if test_only: + t1 = time.time() + result = train_dgl(config, test_only=test_only, use_save=use_save, mp_id_list=mp_id_list) + t2 = time.time() + print("test mae=", result) + print("Toal time:", t2 - t1) + print() + print() + print() + else: + # t1 = time.time() + # result = train_dgl(config, use_save=use_save, mp_id_list=mp_id_list) + # t2 = time.time() + # print("train=", result["train"]) + # print("validation=", result["validation"]) + # print("Toal time:", t2 - t1) + # print() + # print() + # print() + + from matbench.bench import MatbenchBenchmark + + mb = MatbenchBenchmark(subset=[task_name], autoload=False) + + print(f"Running task: {task_name} Fold: {fold_num}") + + if multi_GPU: + + task = next(iter(mb.tasks)) + assert (task.dataset_name == task_name) + task.load() + + train_inputs, train_outputs = task.get_train_and_val_data(fold_num) + test_inputs, test_outputs = task.get_test_data(fold_num, include_target=True) + + + # train_label = train_outputs.values + # print("=0: ", np.sum((train_label < 1e-13)) / train_label.shape[0]) + # print("< 1e-6: ", np.sum((train_label < 1e-6)) / train_label.shape[0]) + # print("1e-6 ~ 1e-2: ", np.sum(np.logical_and((train_label > 1e-6), (train_label < 1e-2))) / train_label.shape[0]) + # print("1e-2 ~ 1e-1: ", np.sum(np.logical_and((train_label > 1e-2), (train_label < 1e-1))) / train_label.shape[0]) + # print("1e-1 ~ 1: ", np.sum(np.logical_and((train_label > 1e-1), (train_label < 1))) / train_label.shape[0]) + # print("> 1: ", np.sum((train_label > 1)) / train_label.shape[0]) + + # import pdb; pdb.set_trace() + + target, predictions = train_dgl(config, + use_save=use_save, + mp_id_list=mp_id_list, + train_inputs=train_inputs, + train_outputs=train_outputs, + test_inputs=test_inputs, + test_outputs=test_outputs, + model_variant=model_variant) + + np.save(os.path.join(output_dir, f"result_fold_{fold_num}.npy"), predictions) + + else: + for task in mb.tasks: + task.load() + for fold in task.folds: + + train_inputs, train_outputs = task.get_train_and_val_data(fold) + test_inputs, test_outputs = task.get_test_data(fold, include_target=True) + + # import pdb; pdb.set_trace() + target, predictions = train_dgl(config, + use_save=use_save, + mp_id_list=mp_id_list, + train_inputs=train_inputs, + train_outputs=train_outputs, + test_inputs=test_inputs, + test_outputs=test_outputs, + model_variant=model_variant) + + # import pdb; pdb.set_trace() + # print("test set error", (target - test_outputs).sum()) + # import pdb; pdb.set_trace() + + # Predict on the testing data + # Your output should be a pandas series, numpy array, or python iterable + # where the array elements are floats or bools + # predictions = my_model.predict(test_inputs) + + # Record your data! 
+                    task.record(fold, predictions)
+
+            # Save your results
+            mb.to_file(os.path.join(output_dir, f"{task_name}.json.gz"))
+
+
diff --git a/benchmarks/matbench_v0.1_iComFormer/utils.py b/benchmarks/matbench_v0.1_iComFormer/utils.py
new file mode 100644
index 00000000..c9cce671
--- /dev/null
+++ b/benchmarks/matbench_v0.1_iComFormer/utils.py
@@ -0,0 +1,45 @@
+"""Shared pydantic settings configuration."""
+"""Implementation based on the template of ALIGNN."""
+import json
+from pathlib import Path
+from typing import Union
+import matplotlib.pyplot as plt
+
+from pydantic import BaseSettings as PydanticBaseSettings
+
+
+class BaseSettings(PydanticBaseSettings):
+    """Add configuration to default Pydantic BaseSettings."""
+
+    class Config:
+        """Configure BaseSettings behavior."""
+
+        extra = "forbid"
+        use_enum_values = True
+        env_prefix = "jv_"
+
+
+def plot_learning_curve(
+    results_dir: Union[str, Path], key: str = "mae", plot_train: bool = False
+):
+    """Plot learning curves based on json history files."""
+    if isinstance(results_dir, str):
+        results_dir = Path(results_dir)
+
+    with open(results_dir / "history_val.json", "r") as f:
+        val = json.load(f)
+
+    p = plt.plot(val[key], label=results_dir.name)
+
+    train = None
+    if plot_train:
+        # plot the training trace in the same color, lower opacity
+        with open(results_dir / "history_train.json", "r") as f:
+            train = json.load(f)
+
+        c = p[0].get_color()
+        plt.plot(train[key], alpha=0.5, c=c)
+
+    plt.xlabel("epochs")
+    plt.ylabel(key)
+
+    return train, val
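+
+
+# Example usage (illustrative sketch only; the results directory below is
+# hypothetical and should point at an output_dir written by a finished
+# training run, which contains history_val.json / history_train.json):
+#
+#     from utils import plot_learning_curve
+#
+#     train_hist, val_hist = plot_learning_curve(
+#         "../matbench_mp_e_form_epoch500_max_25_lr1e-3_4layer_L1_fold0",
+#         key="mae",
+#         plot_train=True,
+#     )
+#     plt.legend()
+#     plt.savefig("learning_curve_fold0.png")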