diff --git a/.gitignore b/.gitignore
new file mode 100644
index 0000000..bfa85c9
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1,24 @@
+# --- .gitignore content start ---
+
+# 1. Ignore LAMMPS logs and trajectory files
+log.lammps
+*.log
+*.lammpstrj
+
+# 2. Ignore data and structure files
+*.data
+POSCAR
+CH4.txt
+*pt
+*.pt
+*.pkl
+# 3. Ignore large model weight files (committing large files to git is generally discouraged unless you need to)
+
+# 4. Ignore Python build caches and packaging files
+__pycache__/
+*.egg-info/
+build/
+dist/
+*.zip
+lammps/
+# --- .gitignore content end ---
diff --git a/README.md b/README.md
index a32fea3..0bb0e6e 100644
--- a/README.md
+++ b/README.md
@@ -6,17 +6,14 @@
 We present **AlphaNet**, a local frame-based equivariant model designed to tackle the challenges of achieving both accurate and efficient simulations for atomistic systems. **AlphaNet** enhances computational efficiency and accuracy by leveraging the local geometric structures of atomic environments through the construction of equivariant local frames and learnable frame transitions. And inspired by Quantum Mechanics, AlphaNet **introduces efficient multi-body message passing by using contraction of matrix product states** rather than common 2-body message passing. Notably, AlphaNet offers one of the best trade-offs between computational efficiency and accuracy among existing models. Moreover, AlphaNet exhibits scalability across a broad spectrum of system and dataset sizes, affirming its versatility.
 markdown
-## Update Log (v0.1.2)
+## Update Log (v0.1.2-beta)
 ### Major Changes
-1. **Added new 2 pretrained models**
-   - Provide a pretrained model for materials: **AlphaNet-MATPES-r2scan** and our first pretrained model for catlysis: **AlphaNet-AQCAT25**, see them in the [pretrained](./pretrained) folder.
-   - Users can **convert the checkpoint trained in torch to our JAX model**
-
-2. **Fixed some bugs**
-   - Support non-periodic boundary conditions in our ase calculator.
-   - Fixed errors in float64
+1. **Added a LAMMPS ML-IAP interface**
+2. **Slight changes to the model architecture**
+3. **Added a finetune option**
+
 ## Installation Guide
@@ -84,7 +81,11 @@
 alpha-train example.json # use --help to see more functions, like multi-gpu trai
 ```bash
 alpha-conv -i in.ckpt -o out.ckpt # use --help to see more functions
 ```
-3. Evaluate a model and draw diagonal plot:
+3. Finetune a converted ckpt:
+```bash
+alpha-train example.json --finetune /path/to/your.ckpt
+```
+4. Evaluate a model and draw a diagonal plot:
 ```bash
 alpha-eval -c example.json -m /path/to/ckpt # use --help to see more functions
 ```
@@ -142,67 +143,17 @@
 print(atoms.get_potential_energy())
 ```
-### Using AlphaNet in JAX
-1. Installation
-   ```bash
-   pip install -U --pre jax jaxlib "jax-cuda12-plugin[with-cuda]" jax-cuda12-pjrt -i https://us-python.pkg.dev/ml-oss-artifacts-published/jax/simple/
-   ```
-   This is just for reference. JAX installation may be tricky, please get more information in [JAX](https://docs.jax.dev/en/latest/installation.html) and its github issues.
-   Currently I suggest **version>=0.4 <=0.4.10 or >=0.4.30 <=0.5 or ==0.6.2**
-   Install flax and haiku
-   ```bash
-   pip install matscipy
-   pip install flax
-   pip install -U dm-haiku
-   ```
-
-2. Converted checkpoints:
-   See pretrained directory
-
-3. Convert a self-trained ckpt
-   First from torch to flax:
-   ```bash
-   python scripts/conv_pt2flax.py #need to modify the path in it.
-   ```
-   Then from flax to haiku:
-
-   ```bash
-   python scripts/flax2haiku.py #need to modify the path in it.
-   ```
-
-4. Performance:
-   The output (energy forces stress) difference from torch model would below 0.001. I ran speed tests on a 4090 GPU, system size from 4 to 300, and get a **2.5x to 3x** speed up.
-   Please note jax model need to be compiled first, so the first run could take a few seconds or minutes, but would be pretty fast after that.
-
-## Dataset Download
-
-[The Defected Bilayer Graphene Dataset](https://zenodo.org/records/10374206)
-
-[The Formate Decomposition on Cu Dataset](https://archive.materialscloud.org/record/2022.45)
-
-[The Zeolite Dataset](https://doi.org/10.6084/m9.figshare.27800211)
-
-[The OC dataset](https://opencatalystproject.org/)
-
-[The MPtrj dataset](https://matbench-discovery.materialsproject.org/data)
-
 ## Pretrained Models
-Current pretrained models:
+Current pretrained models (due to the architecture changes, previous pretrained models need to be updated; this will be done as soon as possible):
 For materials:
-- [AlphaNet-MPtrj-v1](pretrained/MPtrj): A model trained on the MpTrj dataset.
-- [AlphaNet-oma-v1](pretrained/OMA): A model trained on the OMAT24 dataset, and finetuned on sALEX+MPtrj.
-- [AlphaNet-MATPES-r2scan](pretrained/MATPES): A model trained on the MATPES-r2scan dataset.
-For surfaces adsorbtion and reactions:
-- [AlphaNet-AQCAT25](pretrained/AQCAT25): A model trained on the AQCAT25 dataset.
+- [AlphaNet-oma-v1.5](pretrained/OMA): A model trained on the OMAT24 dataset, and finetuned on sALEX+MPtrj.
+
+## Use AlphaNet in LAMMPS
+
+See [mliap_lammps](mliap_lammps.md)
 ## License
@@ -222,3 +173,6 @@
 We thank all contributors and the community for their support. Please open an is
+
+
+
diff --git a/alphanet/cli.py b/alphanet/cli.py
index 22f0253..d0d3cbd 100644
--- a/alphanet/cli.py
+++ b/alphanet/cli.py
@@ -57,7 +57,8 @@ def display_config_table(main_config, runtime_config):
 @click.option("--num_devices", type=int, default=1, help="GPUs per node")
 @click.option("--resume", is_flag=True, help="Resume training from checkpoint")
 @click.option("--ckpt_path", type=click.Path(), default=None, help="Path to checkpoint file")
-def main(config, num_nodes, num_devices, resume, ckpt_path):
+@click.option("--finetune", type=click.Path(exists=True), default=None, help="Path to pretrained checkpoint for finetuning (resets optimizer)")
+def main(config, num_nodes, num_devices, resume, ckpt_path, finetune):
     with open(config, "r") as f:
         mconfig = json.load(f)
@@ -67,7 +68,8 @@ def main(config, num_nodes, num_devices, resume, ckpt_path):
         "num_nodes": num_nodes,
         "num_devices": num_devices,
         "resume": resume,
-        "ckpt_path": ckpt_path
+        "ckpt_path": ckpt_path,
+        "finetune_path": finetune
     }
     display_header()
diff --git a/alphanet/config.py b/alphanet/config.py
index 5bf9354..e94bdde 100644
--- a/alphanet/config.py
+++ b/alphanet/config.py
@@ -2,7 +2,7 @@
 import subprocess
 import json
-import torch
+#import torch
 from typing import Literal, Dict, Optional
 from pydantic_settings import BaseSettings
@@ -22,6 +22,7 @@ class TrainConfig(BaseSettings):
     batch_size: int = 32
     vt_batch_size: int = 32
     lr: float = 0.0005
+    optimizer: str = "radam"
     lr_decay_factor: float = 0.5
     lr_decay_step_size: int = 150
     weight_decay: float = 0
@@ -86,7 +87,13 @@ class AlphaConfig(BaseSettings):
     has_norm_after_flag: bool = False
     reduce_mode: str = "sum"
     zbl: bool = False
-    device: torch.device = torch.device('cuda') if torch.cuda.is_available() else torch.device("cpu")
+    zbl_w: Optional[list] = [0.187,0.3769,0.189,0.081,0.003,0.037,0.0546,0.0715]
+    zbl_b: Optional[list] = [3.20,1.10,0.102,0.958,1.28,1.14,1.69,5]
+    zbl_gamma: float = 1.001
+    zbl_alpha: float = 0.6032
+    zbl_E2: float = 14.399645478425
+    zbl_A0: float =
0.529177210903 + device: str = "cuda" diff --git a/alphanet/create_lammps_model.py b/alphanet/create_lammps_model.py new file mode 100644 index 0000000..7d29be3 --- /dev/null +++ b/alphanet/create_lammps_model.py @@ -0,0 +1,91 @@ +import argparse +import os +import torch +from pathlib import Path + +# Import the AlphaNet model wrapper and config +from alphanet.models.model import AlphaNetWrapper +from alphanet.config import All_Config + +# Import the Python-level LAMMPS interface class +try: + from alphanet.infer.lammps_mliap_alphanet import LAMMPS_MLIAP_ALPHANET +except ImportError: + print("Could not import LAMMPS_MLIAP_ALPHANET.") + print("Please ensure 'alphanet/infer/lammps_mliap_alphanet.py' exists.") + exit(1) + + +def parse_args(): + parser = argparse.ArgumentParser( + description="Convert an AlphaNet model to LAMMPS ML-IAP format (Python Pickle)", + formatter_class=argparse.ArgumentDefaultsHelpFormatter, + ) + parser.add_argument( + "--config", "-c", required=True, type=str, + help="Path to the model configuration JSON file", + ) + parser.add_argument( + "--checkpoint", "-m", required=True, type=str, + help="Path to the trained model checkpoint (.ckpt)", + ) + parser.add_argument( + "--output", "-o", required=True, type=str, + help="Output path to save the model (e.g., alphanet_lammps.pt)", + ) + parser.add_argument( + "--device", type=str, default="cpu", + help="Device to load the model on ('cpu' or 'cuda')", + ) + parser.add_argument( + "--dtype", type=str, default="float64", + choices=["float32", "float64"], + help="Data type for the model", + ) + return parser.parse_args() + +def main(): + args = parse_args() + + device = torch.device(args.device) + + print(f"1. Loading configuration from {args.config}...") + config_obj = All_Config().from_json(args.config) + + config_obj.model.dtype = "64" if args.dtype == "float64" else "32" + + print(f"2. Initializing AlphaNetWrapper (precision: {args.dtype}, device: {args.device})...") + model_wrapper = AlphaNetWrapper(config_obj.model) + + print(f"3. Loading weights from {args.checkpoint}...") + ckpt = torch.load(args.checkpoint, map_location=device) + + if 'state_dict' in ckpt: + state_dict = {k.replace('model.', ''): v for k, v in ckpt['state_dict'].items()} + model_wrapper.model.load_state_dict(state_dict, strict=False) + else: + model_wrapper.load_state_dict(ckpt, strict=False) + + if args.dtype == "float64": + model_wrapper.double() + else: + model_wrapper.float() + + model_wrapper.to(device).eval() + + print("4. Creating LAMMPS ML-IAP Interface Object...") + lammps_interface_object = LAMMPS_MLIAP_ALPHANET(model_wrapper) + + if device.type == 'cuda': + lammps_interface_object.model.cuda() + + print(f"5. 
Saving Python object to {args.output}...") + # Using standard torch.save for Python pickle compatibility + torch.save(lammps_interface_object, args.output) + + print("\n--- Success ---") + print(f"Created LAMMPS model file: {args.output}") + print("Usage in LAMMPS: pair_style mliap model/python ...") + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/alphanet/infer/calc.py b/alphanet/infer/calc.py index 35def11..9208aee 100644 --- a/alphanet/infer/calc.py +++ b/alphanet/infer/calc.py @@ -29,6 +29,8 @@ def __init__(self, ckpt_path, config, device='cpu', precision='32', **kwargs): Calculator.__init__(self, **kwargs) # --- Model Loading --- + if precision == "64": + config.dtype = '64' if ckpt_path.endswith('ckpt'): self.model = AlphaNetWrapper(config).to(torch.device(device)) # Load state dict, ignoring mismatches if any @@ -42,7 +44,7 @@ def __init__(self, ckpt_path, config, device='cpu', precision='32', **kwargs): self.precision = torch.float32 if precision == "32" else torch.float64 if precision == "64": - self.model.double() + self.model.double() self.model.eval() # Set model to evaluation mode self.model.to(self.device) diff --git a/alphanet/infer/lammps_mliap_alphanet.py b/alphanet/infer/lammps_mliap_alphanet.py new file mode 100644 index 0000000..073e0fa --- /dev/null +++ b/alphanet/infer/lammps_mliap_alphanet.py @@ -0,0 +1,431 @@ +import logging +import math +from typing import Dict, Tuple, Optional +from math import pi + +import torch +import torch.nn.functional as F +from torch import nn, Tensor +from torch_scatter import scatter +from ase.data import chemical_symbols + +# Import custom model modules +from alphanet.models.alphanet import AlphaNet +from alphanet.models.model import AlphaNetWrapper +# from alphanet.models.graph import GraphData, get_max_neighbors_mask # Uncomment if needed + +# Try to import LAMMPS interface +try: + from lammps.mliap.mliap_unified_abc import MLIAPUnified +except ImportError: + class MLIAPUnified: + def __init__(self): pass + print("Warning: LAMMPS MLIAP-Unified interface not found. Creating dummy class.") + + +class LAMMPS_MP(torch.autograd.Function): + """ + Handles MPI communication for gradients between Local and Ghost atoms within LAMMPS. + """ + @staticmethod + def forward(ctx, *args): + feats, data = args + ctx.vec_len = feats.shape[-1] + ctx.data = data + + out = torch.empty_like(feats) + if not feats.is_contiguous(): + feats = feats.contiguous() + + # Forward exchange: Ghost atoms get data from their Local owners on other procs + data.forward_exchange(feats, out, ctx.vec_len) + return out + + @staticmethod + def backward(ctx, *grad_outputs): + (grad,) = grad_outputs + + gout = grad.clone() + if not gout.is_contiguous(): + gout = gout.contiguous() + + # Reverse exchange: Sum gradients from Ghost atoms back to their Local owners + ctx.data.reverse_exchange(grad, gout, ctx.vec_len) + + return gout, None + + +class AlphaNetEdgeForcesWrapper(torch.nn.Module): + """ + Wrapper for AlphaNet to compute forces via edge gradients directly. 
+ """ + def __init__(self, model: AlphaNet): + super().__init__() + + # Copy submodules + self.z_emb = model.z_emb + self.z_emb_ln = model.z_emb_ln + self.radial_emb = model.radial_emb + self.radial_lin = model.radial_lin + self.neighbor_emb = model.neighbor_emb + self.S_vector = model.S_vector + self.lin = model.lin + self.message_layers = model.message_layers + self.FTEs = model.FTEs + self.last_layer = model.last_layer + self.last_layer_quantum = model.last_layer_quantum + + # Copy Parameters + self.a = model.a + self.b = model.b + self.kernel1 = model.kernel1 + + if isinstance(model.kernels_real, torch.Tensor): + self.kernels_real = nn.ParameterList([nn.Parameter(p) for p in model.kernels_real]) + else: + self.kernels_real = model.kernels_real + + if isinstance(model.kernels_imag, torch.Tensor): + self.kernels_imag = nn.ParameterList([nn.Parameter(p) for p in model.kernels_imag]) + else: + self.kernels_imag = model.kernels_imag + + # Metadata + self.cutoff = model.cutoff + self.pi = model.pi + self.eps = 1e-9 + self.hidden_channels = model.hidden_channels + self.chi1 = model.chi1 + self.complex_type = model.complex_type + self.rcutfac = float(model.cutoff) + + self.register_buffer("atomic_numbers", torch.arange(1, 95)) + self.eval() + + def handle_lammps(self, tensor: Tensor, lammps_class: Optional[object], natoms: Tensor) -> Tensor: + """ + Syncs tensor data for ghost atoms if running inside LAMMPS with MPI. + """ + if lammps_class is None: + return tensor + + n_local = int(natoms[0]) + n_total = int(natoms[1]) + n_current = tensor.size(0) + current_dim = tensor.size(1) + + # Pad if necessary + if n_current == n_local and n_total > n_local: + padding = torch.zeros((n_total - n_local, current_dim), dtype=tensor.dtype, device=tensor.device) + tensor_full = torch.cat([tensor, padding], dim=0) + else: + tensor_full = tensor + + if n_total > n_local: + if tensor_full.device != self.a.device: + tensor_full = tensor_full.to(self.a.device) + + orig_dtype = tensor_full.dtype + if orig_dtype != torch.float64: + tensor_full = tensor_full.to(torch.float64) + + target_dim = getattr(lammps_class, "ndescriptors", current_dim) + + # Pad dimension if needed + if current_dim < target_dim: + pad_width = target_dim - current_dim + col_padding = torch.zeros((tensor_full.size(0), pad_width), dtype=tensor_full.dtype, device=tensor_full.device) + tensor_ready = torch.cat([tensor_full, col_padding], dim=1) + else: + tensor_ready = tensor_full + + if not tensor_ready.is_contiguous(): + tensor_ready = tensor_ready.contiguous() + + if tensor_ready.is_cuda: + torch.cuda.synchronize() + + # Sync data using custom autograd function + tensor_synced = LAMMPS_MP.apply(tensor_ready, lammps_class) + + if current_dim < target_dim: + tensor_synced = tensor_synced[:, :current_dim] + if tensor_synced.dtype != orig_dtype: + tensor_synced = tensor_synced.to(orig_dtype) + return tensor_synced + + return tensor_full + + def forward(self, data: Dict[str, Tensor]) -> Tuple[Tensor, Tensor, Tensor]: + pos = data["positions"] + z = data["node_attrs"] + edge_index = data["edge_index"] + edge_vec = data["vectors"] + + # Ensure vectors require grad for force computation + if not edge_vec.requires_grad: + edge_vec = edge_vec.requires_grad_() + + lammps_ptr = data.get("lammps_ptr") + natoms_info = data["natoms"] + + n_local = int(natoms_info[0]) + n_total = pos.size(0) + + dist = torch.linalg.norm(edge_vec, dim=1) + z_emb = self.z_emb_ln(self.z_emb(z)) + radial_emb = self.radial_emb(dist) + radial_hidden = self.radial_lin(radial_emb) + 
rbounds = 0.5 * (torch.cos(dist * self.pi / self.cutoff) + 1.0) + radial_hidden = rbounds.unsqueeze(-1) * radial_hidden + + s = self.neighbor_emb(z, z_emb, edge_index, radial_hidden) + + vec = torch.zeros(n_local, 3, s.size(1), device=s.device, dtype=s.dtype) + s = s[:n_local] + + j = edge_index[0] + i = edge_index[1] + edge_diff = edge_vec / (dist.unsqueeze(1) + self.eps) + + edge_vec_mean = scatter(edge_vec, i, reduce='mean', dim=0, dim_size=n_total) + edge_cross = torch.cross(edge_vec, edge_vec_mean[i]) + edge_vertical = torch.cross(edge_diff, edge_cross) + edge_frame = torch.cat((edge_diff.unsqueeze(-1), edge_cross.unsqueeze(-1), edge_vertical.unsqueeze(-1)), dim=-1) + + # Sync initial s features + s_full = self.handle_lammps(s, lammps_ptr, natoms_info) + + S_i_j = self.S_vector(s_full, edge_diff.unsqueeze(-1), edge_index, radial_hidden)[:n_local] + sij_flat = S_i_j.reshape(n_local, -1) + sij_full_flat = self.handle_lammps(sij_flat, lammps_ptr, natoms_info) + S_i_j = sij_full_flat.reshape(n_total, 3, -1) + + scalrization1 = torch.sum(S_i_j[i].unsqueeze(2) * edge_frame.unsqueeze(-1), dim=1) + scalrization2 = torch.sum(S_i_j[j].unsqueeze(2) * edge_frame.unsqueeze(-1), dim=1) + scalrization1[:, 1, :] = torch.square(scalrization1[:, 1, :].clone()) + scalrization2[:, 1, :] = torch.square(scalrization2[:, 1, :].clone()) + + scalar3 = (self.lin(torch.permute(scalrization1, (0, 2, 1))) + + torch.permute(scalrization1, (0, 2, 1))[:, :, 0].unsqueeze(2)).squeeze(-1) / math.sqrt(self.hidden_channels) + scalar4 = (self.lin(torch.permute(scalrization2, (0, 2, 1))) + + torch.permute(scalrization2, (0, 2, 1))[:, :, 0].unsqueeze(2)).squeeze(-1) / math.sqrt(self.hidden_channels) + + edge_weight = torch.cat((scalar3, scalar4), dim=-1) * rbounds.unsqueeze(-1) + edge_weight = torch.cat((edge_weight, radial_hidden, radial_emb), dim=-1) + + quantum = torch.einsum('ik,bi->bk', self.kernel1, z_emb[:n_local]) + real, imagine = torch.split(quantum, self.chi1, dim=-1) + quantum = torch.complex(real, imagine) + + rope: Optional[Tensor] = None + + # Message Passing Layers + for i_layer, (message_layer, fte, kernel_real, kernel_imag) in enumerate(zip( + self.message_layers, self.FTEs, self.kernels_real, self.kernels_imag + )): + # Sync s + s_full = self.handle_lammps(s, lammps_ptr, natoms_info) + + # Sync vec + vec_flat = vec.reshape(n_local, -1) + vec_full_flat = self.handle_lammps(vec_flat, lammps_ptr, natoms_info) + vec_full = vec_full_flat.reshape(n_total, 3, -1) + + # Sync rope + if rope is not None: + if rope.is_complex(): + rope_real = torch.view_as_real(rope).flatten(1) + rope_full_real = self.handle_lammps(rope_real, lammps_ptr, natoms_info) + rope_full = torch.view_as_complex(rope_full_real.view(n_total, -1, 2)) + else: + rope_full = self.handle_lammps(rope, lammps_ptr, natoms_info) + else: + rope_full = None + + new_rope_full, ds_full, dvec_full = message_layer(s_full, vec_full, edge_index, radial_emb, edge_weight, edge_diff, rope_full) + + rope = new_rope_full[:n_local] + ds = ds_full[:n_local] + dvec = dvec_full[:n_local] + s = s + ds + vec = vec + dvec + + kerneli = torch.complex(kernel_real, kernel_imag) + quantum = torch.einsum('ikl,bi,bl->bk', kerneli, s.to(self.complex_type), quantum) + quantum = quantum / (self.eps + quantum.abs().to(self.complex_type)) + ds_fte, dvec_fte = fte(s, vec) + s = s + ds_fte + vec = vec + dvec_fte + + # Final Readout + s_per_atom = self.last_layer(s) + self.last_layer_quantum(torch.cat([quantum.real, quantum.imag], dim=-1)) / self.chi1 + node_energy = 
(self.a[z[:n_local]].unsqueeze(1) * s_per_atom + self.b[z[:n_local]].unsqueeze(1)).squeeze(-1) + + if node_energy.shape[0] > n_local: + s_total = node_energy[:n_local].sum() + else: + s_total = node_energy.sum() + + # Compute Gradients (Force = -dE/dr, computed here as gradients w.r.t edge vectors) + if s_total.grad_fn is not None: + grads = torch.autograd.grad( + outputs=[s_total], + inputs=[edge_vec], + grad_outputs=[torch.ones_like(s_total)], + retain_graph=False, + create_graph=False, + allow_unused=True, + )[0] + if grads is None: + grads = torch.zeros_like(edge_vec) + else: + grads = torch.zeros_like(edge_vec) + + pair_forces = grads + + return s_total, node_energy, pair_forces + + +class LAMMPS_MLIAP_ALPHANET(MLIAPUnified): + def __init__(self, model_wrapper: AlphaNetWrapper, **kwargs): + super().__init__() + + internal_model = model_wrapper.model + internal_model.double() + internal_model.eval() + edge_wrapper = AlphaNetEdgeForcesWrapper(internal_model).eval() + edge_wrapper.double() + + self.model = edge_wrapper + self.element_types = [chemical_symbols[i] for i in range(1, 95)] + self.num_species = 94 + self.rcutfac = 0.5 * float(model_wrapper.model.cutoff) + self.hidden_dim = internal_model.hidden_channels + + target_dim = 3 * self.hidden_dim + self.ndescriptors = target_dim + self.nparams = target_dim + self.dtype = model_wrapper.precision + self.device = "cpu" + self.initialized = False + self.step = 0 + + def _initialize_device(self, data): + if torch.cuda.is_available(): + self.device = torch.device("cuda") + else: + self.device = torch.device("cpu") + self.model = self.model.to(self.device) + self.model.eval() + self.initialized = True + + def compute_forces(self, data): + natoms = data.nlocal + ntotal = data.ntotal + nghosts = ntotal - natoms + npairs = data.npairs + + if not self.initialized: + self.model.to(self.dtype) + self._initialize_device(data) + + if hasattr(data.elems, 'get'): + elems_np = data.elems.get() + else: + elems_np = data.elems + + species = torch.as_tensor(elems_np, dtype=torch.int64, device=self.device) + self.step += 1 + + if natoms == 0 or npairs <= 1: + return + + # 1. Prepare data batch + batch = self._prepare_batch(data, natoms, nghosts, species) + + # 2. Forward pass + batch["vectors"].requires_grad_(True) + total_energy, atom_energies, pair_forces = self.model(batch) + + if self.device.type == "cuda": + torch.cuda.synchronize() + + # 3. 
Update LAMMPS with energies and forces + self._update_lammps_data(data, atom_energies, pair_forces, natoms, total_energy) + + def _update_lammps_data(self, data, atom_energies, pair_forces, natoms, total_energy): + if self.dtype == torch.float32: + pair_forces = pair_forces.double() + atom_energies = atom_energies.double() + total_energy = total_energy.double() + + atom_energies_cpu = atom_energies[:natoms].detach().cpu().numpy() + + # Update atom energies + if hasattr(data.eatoms, 'get'): + try: data.eatoms[:natoms] = atom_energies_cpu + except: pass + else: + try: + eatoms_tensor = torch.as_tensor(data.eatoms) + eatoms_tensor[:natoms].copy_(torch.from_numpy(atom_energies_cpu)) + except: + pass + + data.energy = total_energy.item() + + final_forces = pair_forces.detach() + + if not final_forces.is_contiguous(): + final_forces = final_forces.contiguous() + + # Update pair forces (GPU or CPU) + if self.device.type == 'cuda': + try: + from torch.utils.dlpack import to_dlpack + from cupy import from_dlpack + force_cupy = from_dlpack(to_dlpack(final_forces)) + data.update_pair_forces_gpu(force_cupy) + except ImportError: + data.update_pair_forces_gpu(final_forces.cpu().numpy()) + else: + final_forces_np = final_forces.numpy() + data.update_pair_forces_gpu(final_forces_np) + + def _prepare_batch(self, data, natoms, nghosts, species) -> Dict[str, object]: + positions = torch.zeros((natoms + nghosts, 3), dtype=self.dtype, device=self.device) + node_attrs = species + 1 + batch_tensor = torch.zeros(natoms, dtype=torch.int64, device=self.device) + natoms_tensor = torch.tensor([natoms, natoms+nghosts], dtype=torch.int64, device=self.device) + + if hasattr(data.rij, 'get'): rij_data = data.rij.get() + else: rij_data = data.rij + + if hasattr(data.pair_i, 'get'): pair_i = data.pair_i.get() + else: pair_i = data.pair_i + + if hasattr(data.pair_j, 'get'): pair_j = data.pair_j.get() + else: pair_j = data.pair_j + + rij_tensor = torch.as_tensor(rij_data, dtype=self.dtype, device=self.device) + target_tensor = torch.as_tensor(pair_i, dtype=torch.int64, device=self.device) + source_tensor = torch.as_tensor(pair_j, dtype=torch.int64, device=self.device) + + edge_index = torch.stack([source_tensor, target_tensor], dim=0) + + return { + "positions": positions, + "vectors": rij_tensor, + "node_attrs": node_attrs, + "edge_index": edge_index, + "batch": batch_tensor, + "natoms": natoms_tensor, + "lammps_ptr": data, + } + + def compute_descriptors(self, data: Dict[str, Tensor]) -> None: + pass + + def compute_gradients(self, data: Dict[str, Tensor]) -> None: + pass \ No newline at end of file diff --git a/alphanet/models/__init__.py b/alphanet/models/__init__.py index 83c1adf..72a417a 100644 --- a/alphanet/models/__init__.py +++ b/alphanet/models/__init__.py @@ -1,2 +1,2 @@ -from .alphanet import AlphaNet +#from .alphanet import AlphaNet #from .alpha_flax import AlphaNet_flax \ No newline at end of file diff --git a/alphanet/models/alpha_haiku.py b/alphanet/models/alpha_haiku.py index ec60ea8..fca7c87 100644 --- a/alphanet/models/alpha_haiku.py +++ b/alphanet/models/alpha_haiku.py @@ -12,6 +12,8 @@ import math from functools import partial +from alphanet.models.zbl_jax import zbl_interaction, get_default_zbl_params + class Config: def __init__(self, **kwargs): for k, v in kwargs.items(): @@ -496,58 +498,38 @@ def __call__(self, data): s = self.last_layer(s) + self.last_layer_quantum(quantum_features) / self.config.main_chi1 a_values = a[z] b_values = b[z] - V_graph = 0 - if self.zbl: - r_e = dist - Z_j = z[j] - Z_i 
= z[i]
-
-            w = self.fzbl_w  # (M,)
-            b = self.fzbl_b  # (M,)
-            gamma = self.fzbl_gamma
-            alpha = self.fzbl_alpha
-            E2 = self.fzbl_E2
-            A0 = self.fzbl_A0
-
-            # compute screening length a per edge: a = gamma * 0.8854 * a0 / (Z1^alpha + Z2^alpha)
-            denom = jnp.power(Z_j, alpha) + jnp.power(Z_i, alpha)  # (E,)
-            denom = jnp.clip(denom, a_min=1e-12)
-            a_vals = gamma * 0.8854 * A0 / denom  # (E,)
-            x = r_e / a_vals  # (E,)
-
-            # compute phi(x) = sum_i w_i * exp(-b_i * x) (vectorized)
-            # exp(- x[:,None] * b[None,:]) -> (E, M)
-            exp_terms = jnp.exp(- x[:, jnp.newaxis] * b[jnp.newaxis, :])  # (E, M)
-            phi_vals = exp_terms @ w  # (E,)
-
-            # pair potential per edge: V_e = Z1*Z2 * E2 * phi / r
-            V_edge = (Z_j * Z_i * E2) * (phi_vals / r_e)  # (E,)
-            r_cut = 1.0  # you can make this self.fzbl_rcut buffer if you want configurable value
-
-            # compute taper coefficient: cosine cutoff (smooth)
-            # for r in [0, r_cut]: c = 0.5*(cos(pi * r / r_cut) + 1)
-            # for r >= r_cut: c = 0
-            # for safety, clamp r/r_cut in [0, 1]
-            xrc = jnp.clip(r_e / r_cut, 0.0, 1.0)  # (E,)
-            # cosine taper
-            c = 0.5 * (jnp.cos(jnp.pi * xrc) + 1.0)  # (E,)
-            # enforce zero beyond r_cut explicitly (cos already gives 0 at x=1 but clamp keeps numeric safe)
-            c = jnp.where(r_e >= r_cut, jnp.zeros_like(c), c)
-
-            # apply taper to edge potential
-            V_edge = V_edge * c
+        ml_energy = a_values * s.squeeze() + b_values
+        if self.config.zbl:
+            # Fetch the parameters (defaults are assumed here; read them from the config if it provides them)
+            zbl_params = get_default_zbl_params(self.dtype)
+
+            # They could also be registered as non-trainable hk parameters or as constants
+            # ...
+
+            V_edge = zbl_interaction(
+                dist, z[i], z[j],
+                zbl_params['w'], zbl_params['b'],
+                zbl_params['gamma'], zbl_params['alpha'],
+                zbl_params['E2'], zbl_params['A0']
+            )
+
+            # [Key step] distribute the pair energies onto atoms
+            # i is the target (receiver) index
+            num_atoms = z.shape[0]
+            zbl_per_atom = jax.ops.segment_sum(V_edge, i, num_segments=num_atoms) * 0.5
-            # aggregate edge energies to graph-level using jax.ops.segment_sum
-            # Note: JAX uses segment_sum instead of scatter_add
-            graph_idx = batch[i]  # map receiver node -> graph index (E,)
-            V_graph = jax.ops.segment_sum(V_edge, graph_idx, num_segments=1) / 2.0
+            # add the ZBL term to the total per-atom energy
+            total_atom_energy = ml_energy + zbl_per_atom
+        else:
+            total_atom_energy = ml_energy

         if s.ndim == 2:
             s = a_values[:, None] * s + b_values[:, None]
         else:
             s = a_values * s + b_values
             s = s[:, None]
-        s = jnp.sum(s)+V_graph#jax.ops.segment_sum(s, batch, num_segments=1)+ Vgraph
-        return jnp.squeeze(s)
+        s_total = jax.ops.segment_sum(total_atom_energy, batch, num_segments=1)  # assumes batch_size=1 for inference
+
+        return s_total.squeeze()
diff --git a/alphanet/models/alphanet.py b/alphanet/models/alphanet.py
index 52dec7b..242ab51 100644
--- a/alphanet/models/alphanet.py
+++ b/alphanet/models/alphanet.py
@@ -1,41 +1,72 @@
-
 import math
 from math import pi
-from typing import Optional, Tuple, List, NamedTuple
-from typing import Literal
+from typing import Optional, Tuple, List
+
 import torch
-from torch import nn
-from torch import Tensor
+from torch import nn, Tensor
 from torch.nn import Embedding
 from torch_geometric.nn.conv import MessagePassing
-from torch_scatter import scatter, scatter_add
+
 from alphanet.models.graph import GraphData
+import numpy as np
+
+
+def scatter(src: torch.Tensor, index: torch.Tensor, dim: int = -1,
+            out: Optional[torch.Tensor] = None, dim_size: Optional[int] = None,
+            reduce: str = "sum") -> torch.Tensor:
+    """
+    Drop-in replacement for torch_scatter.scatter using native PyTorch functions.
+ """ + if out is not None: + dim_size = out.size(dim) + else: + if dim_size is None: + dim_size = int(index.max()) + 1 if index.numel() > 0 else 0 + + out_size = list(src.size()) + out_size[dim] = dim_size + + if index.dim() != src.dim(): + curr_dims = index.dim() + target_dims = src.dim() + for _ in range(target_dims - curr_dims): + index = index.unsqueeze(-1) + index = index.expand_as(src) + + reduce = reduce.lower() + + if reduce in ['sum', 'add']: + if out is None: + out = torch.zeros(out_size, dtype=src.dtype, device=src.device) + return out.scatter_add_(dim, index, src) + + if reduce == 'mean': + mode = 'mean' + init_val = 0.0 + elif reduce in ['min', 'amin']: + mode = 'amin' + init_val = float('inf') + elif reduce in ['max', 'amax']: + mode = 'amax' + init_val = float('-inf') + else: + raise ValueError(f"Unknown reduce mode: {reduce}") + + if out is None: + out = torch.full(out_size, init_val, dtype=src.dtype, device=src.device) + + out.scatter_reduce_(dim, index, src, reduce=mode, include_self=False) + return out class rbf_emb(nn.Module): r_max: float prefactor: float - def __init__(self, num_basis=8, r_max = 5.0, trainable=True): - r"""Radial Bessel Basis, as proposed in DimeNet: https://arxiv.org/abs/2003.03123 - - - Parameters - ---------- - r_max : float - Cutoff radius - - num_basis : int - Number of Bessel Basis functions - - trainable : bool - Train the :math:`n \pi` part or not. - """ + def __init__(self, num_basis=8, r_max=5.0, trainable=True): super(rbf_emb, self).__init__() - self.trainable = trainable self.num_basis = num_basis - self.r_max = r_max self.prefactor = 2.0 / self.r_max @@ -48,61 +79,12 @@ def __init__(self, num_basis=8, r_max = 5.0, trainable=True): self.register_buffer("bessel_weights", bessel_weights) def forward(self, x: torch.Tensor) -> torch.Tensor: - """ - Evaluate Bessel Basis for input x. 
- - Parameters - ---------- - x : torch.Tensor - Input - """ numerator = torch.sin(self.bessel_weights * x.unsqueeze(-1) / self.r_max) - return self.prefactor * (numerator / x.unsqueeze(-1)) - -class _rbf_emb(nn.Module): - ''' - modified: delete cutoff with r - ''' - - def __init__(self, num_rbf, rbound_upper, rbf_trainable=False): - super().__init__() - self.rbound_upper = rbound_upper - self.rbound_lower = 0 - self.num_rbf = num_rbf - self.rbf_trainable = rbf_trainable - self.pi = pi - means, betas = self._initial_params() - - self.register_buffer("means", means) - self.register_buffer("betas", betas) - - def _initial_params(self): - start_value = torch.exp(torch.scalar_tensor(-self.rbound_upper)) - end_value = torch.exp(torch.scalar_tensor(-self.rbound_lower)) - means = torch.linspace(start_value, end_value, self.num_rbf) - betas = torch.tensor([(2 / self.num_rbf * (end_value - start_value)) ** -2] * - self.num_rbf) - return means, betas - - def reset_parameters(self): - means, betas = self._initial_params() - self.means.data.copy_(means) - self.betas.data.copy_(betas) - - def forward(self, dist): - dist = dist.unsqueeze(-1) - rbounds = 0.5 * \ - (torch.cos(dist * self.pi / self.rbound_upper) + 1.0) - rbounds = rbounds * (dist < self.rbound_upper).float() - return rbounds * torch.exp(-self.betas * torch.square((torch.exp(-dist) - self.means))) class NeighborEmb(MessagePassing): - propagate_type = { - 'x': Tensor, - 'norm': Tensor - } + propagate_type = {'x': Tensor, 'norm': Tensor} def __init__(self, hid_dim: int): super(NeighborEmb, self).__init__(aggr='add') @@ -110,13 +92,7 @@ def __init__(self, hid_dim: int): self.hid_dim = hid_dim self.ln_emb = nn.LayerNorm(hid_dim, elementwise_affine=False) - def forward( - self, - z: Tensor, - s: Tensor, - edge_index: Tensor, - embs: Tensor - ) -> Tensor: + def forward(self, z: Tensor, s: Tensor, edge_index: Tensor, embs: Tensor) -> Tensor: s_neighbors = self.ln_emb(self.embedding(z)) s_neighbors = self.propagate(edge_index, x=s_neighbors, norm=embs) s = s + s_neighbors @@ -127,10 +103,7 @@ def message(self, x_j: Tensor, norm: Tensor) -> Tensor: class S_vector(MessagePassing): - propagate_type = { - 'x': Tensor, - 'norm': Tensor - } + propagate_type = {'x': Tensor, 'norm': Tensor} def __init__(self, hid_dim: int): super(S_vector, self).__init__(aggr='add') @@ -140,13 +113,7 @@ def __init__(self, hid_dim: int): nn.LayerNorm(hid_dim, elementwise_affine=False), nn.SiLU()) - def forward( - self, - s: Tensor, - v: Tensor, - edge_index: Tensor, - emb: Tensor - ) -> Tensor: + def forward(self, s: Tensor, v: Tensor, edge_index: Tensor, emb: Tensor) -> Tensor: s = self.lin1(s) emb = emb.unsqueeze(1) * v v = self.propagate(edge_index, x=s, norm=emb) @@ -157,13 +124,10 @@ def message(self, x_j: Tensor, norm: Tensor) -> Tensor: a = norm.view(-1, 3, self.hid_dim) * x_j return a.view(-1, 3 * self.hid_dim) -class EquiMessagePassing(MessagePassing): +class EquiMessagePassing(MessagePassing): propagate_type = { - 'xh': Tensor, - 'vec': Tensor, - 'rbfh_ij': Tensor, - 'r_ij': Tensor + 'xh': Tensor, 'vec': Tensor, 'rbfh_ij': Tensor, 'r_ij': Tensor } def __init__( @@ -177,7 +141,7 @@ def __init__( has_dropout_flag: bool = False, has_norm_before_flag=True, has_norm_after_flag=False, - complex_type = torch.complex64, + complex_type=torch.complex64, reduce_mode='sum', device=torch.device('cuda') if torch.cuda.is_available() else torch.device("cpu") ): @@ -193,9 +157,12 @@ def __init__( self.hidden_channels_chi = hidden_channels_chi self.scale = 
nn.Linear(self.hidden_channels, self.hidden_channels_chi * 2) self.num_radial = num_radial + self.dir_proj = nn.Sequential( - nn.Linear(3 * self.hidden_channels + self.num_radial, self.hidden_channels * 3), nn.SiLU(inplace=True), - nn.Linear(self.hidden_channels * 3, self.hidden_channels * 3), ) + nn.Linear(3 * self.hidden_channels + self.num_radial, self.hidden_channels * 3), + nn.SiLU(inplace=True), + nn.Linear(self.hidden_channels * 3, self.hidden_channels * 3) + ) self.x_proj = nn.Sequential( nn.Linear(hidden_channels, hidden_channels), @@ -217,26 +184,26 @@ def __init__( self.dx_layer_norm = nn.LayerNorm(self.chi1) if self.has_norm_before_flag: self.dx_layer_norm = nn.LayerNorm(self.chi1 + self.hidden_channels) + self.dropout = nn.Dropout(p=0.5) self.diachi1 = torch.nn.Parameter(torch.randn((self.chi1), device=self.device)) self.scale2 = nn.Sequential( - nn.Linear(self.chi1, hidden_channels//2), + nn.Linear(self.chi1, hidden_channels // 2), ) self.kernel_real = torch.nn.Parameter(torch.randn((self.head + 1, (self.hidden_channels_chi) // self.head, self.chi2))) self.kernel_imag = torch.nn.Parameter(torch.randn((self.head + 1, (self.hidden_channels_chi) // self.head, self.chi2))) - self.fc_mps = nn.Linear(self.chi1, self.chi1)#.to(torch.cfloat) - self.fc_dx = nn.Linear(self.chi1, hidden_channels)#.to(torch.cfloat) - self.dia = nn.Linear(self.chi1, self.chi1)#.to(torch.cfloat) + self.fc_mps = nn.Linear(self.chi1, self.chi1) + self.fc_dx = nn.Linear(self.chi1, hidden_channels) + self.dia = nn.Linear(self.chi1, self.chi1) self.unitary = torch.nn.Parameter(torch.randn((self.chi1, self.chi1), device=self.device)) self.activation = nn.SiLU() self.inv_sqrt_3 = 1 / math.sqrt(3.0) self.inv_sqrt_h = 1 / math.sqrt(hidden_channels) - self.x_layernorm = nn.LayerNorm(hidden_channels) - + self.reset_parameters() def reset_parameters(self): @@ -247,7 +214,6 @@ def reset_parameters(self): nn.init.xavier_uniform_(self.rbf_proj.weight) self.rbf_proj.bias.data.fill_(0) self.x_layernorm.reset_parameters() - nn.init.xavier_uniform_(self.dir_proj[0].weight) self.dir_proj[0].bias.data.fill_(0) @@ -265,16 +231,16 @@ def forward( rope: Optional[Tensor] = None ) -> Tuple[Tensor, Tensor, Tensor]: if rope is not None: - real, imag = torch.split(x, [self.hidden_channels//2, self.hidden_channels//2], dim=-1) + real, imag = torch.split(x, [self.hidden_channels // 2, self.hidden_channels // 2], dim=-1) dy_pre = torch.complex(real=real, imag=imag) - dy_pre = dy_pre* rope + dy_pre = dy_pre * rope x = torch.cat([dy_pre.real, dy_pre.imag], dim=-1) + xh = self.x_proj(self.x_layernorm(x)) - rbfh = self.rbf_proj(edge_rbf) weight = self.dir_proj(weight) rbfh = rbfh * weight - # propagate_type: (xh: Tensor, vec: Tensor, rbfh_ij: Tensor, r_ij: Tensor) + dx, dvec = self.propagate( edge_index, xh=xh, @@ -282,8 +248,8 @@ def forward( rbfh_ij=rbfh, r_ij=edge_vector, size=None, - # rotation = unitary, ) + if self.has_norm_before_flag: dx = self.dx_layer_norm(dx) @@ -293,7 +259,6 @@ def forward( dx = self.dx_layer_norm(dx) dx = self.scale2(dx) - dx = torch.complex(torch.cos(dx), torch.sin(dx)) return dx, dy, dvec @@ -309,21 +274,23 @@ def message(self, xh_j, vec_j, rbfh_ij, r_ij): real = self.dropout(real) imagine = self.dropout(imagine) - # complex invariant quantum state phi = torch.complex(real, imagine) q = phi - a = torch.ones(q.shape[0], 1, (self.hidden_channels_chi) // self.head, device=self.device, dtype= self.complex_type) + a = torch.ones(q.shape[0], 1, (self.hidden_channels_chi) // self.head, device=self.device, 
dtype=self.complex_type) kernel = (torch.complex(self.kernel_real, self.kernel_imag) / math.sqrt((self.hidden_channels) // self.head)).expand(q.shape[0], -1, -1, -1) + equation = 'ijl, ijlk->ik' - conv = torch.einsum(equation, torch.cat([a, q], dim=1), kernel.to( self.complex_type)) + conv = torch.einsum(equation, torch.cat([a, q], dim=1), kernel.to(self.complex_type)) a = 1.0 * self.activation(self.diagonal(rbfh_ij)) b = a.unsqueeze(-1) * self.diachi1.unsqueeze(0).unsqueeze(0) + torch.ones(kernel.shape[0], self.chi2, self.chi1, device=self.device) dia = self.dia(b) + equation = 'ik,ikl->il' kernel = torch.einsum(equation, conv, dia.to(self.complex_type)) - kernel_real,kernel_imag = kernel.real,kernel.imag - kernel_real,kernel_imag = self.fc_mps(kernel_real),self.fc_mps(kernel_imag) + kernel_real, kernel_imag = kernel.real, kernel.imag + kernel_real, kernel_imag = self.fc_mps(kernel_real), self.fc_mps(kernel_imag) kernel = torch.angle(torch.complex(kernel_real, kernel_imag)) + agg = torch.cat([kernel, x], dim=-1) vec = vec_j * xh2.unsqueeze(1) + xh3.unsqueeze(1) * r_ij.unsqueeze(2) vec = vec * self.inv_sqrt_h @@ -339,7 +306,7 @@ def aggregate( ) -> Tuple[torch.Tensor, torch.Tensor]: x, vec = features x = scatter(x, index, dim=self.node_dim, dim_size=dim_size, reduce=self.reduce_mode) - vec = scatter(vec, index, dim=self.node_dim, dim_size=dim_size) + vec = scatter(vec, index, dim=self.node_dim, dim_size=dim_size, reduce='sum') return x, vec def update( @@ -361,12 +328,10 @@ def __init__(self, hidden_channels): nn.Linear(hidden_channels * 2, hidden_channels), nn.SiLU(), nn.Linear(hidden_channels, hidden_channels * 3) - ) self.inv_sqrt_2 = 1 / math.sqrt(2.0) self.inv_sqrt_h = 1 / math.sqrt(hidden_channels) - self.reset_parameters() def reset_parameters(self): @@ -378,50 +343,29 @@ def reset_parameters(self): def forward(self, x, vec): vec = self.vec_proj(vec) - vec1, vec2 = torch.split( - vec, self.hidden_channels, dim=-1 - ) + vec1, vec2 = torch.split(vec, self.hidden_channels, dim=-1) - scalar = torch.norm(vec1, dim=-2, p=1) + scalar = torch.sum(vec1**2, dim=-2) vec_dot = (vec1 * vec2).sum(dim=1) vec_dot = vec_dot * self.inv_sqrt_h - x_vec_h = self.xvec_proj( - torch.cat( - [x, scalar], dim=-1 - ) - ) - xvec1, xvec2, xvec3 = torch.split( - x_vec_h, self.hidden_channels, dim=-1 - ) + x_vec_h = self.xvec_proj(torch.cat([x, scalar], dim=-1)) + xvec1, xvec2, xvec3 = torch.split(x_vec_h, self.hidden_channels, dim=-1) dx = xvec1 + xvec2 + vec_dot dx = dx * self.inv_sqrt_2 dvec = xvec3.unsqueeze(1) * vec2 - return dx, dvec -class aggregate_pos(MessagePassing): - - def __init__(self, aggr='mean'): - super(aggregate_pos, self).__init__(aggr=aggr) - - def forward(self, vector, edge_index): - v = self.propagate(edge_index, x=vector) - - return v - - class AlphaNet(nn.Module): - def __init__(self, config, device=torch.device('cuda') if torch.cuda.is_available() else torch.device("cpu")): super(AlphaNet, self).__init__() self.device = device self.complex_type = torch.complex64 if config.dtype == "32" else torch.complex128 - self.eps = config.eps + self.eps = 1e-9 self.num_layers = config.num_layers self.hidden_channels = config.hidden_channels self.a = nn.Parameter(torch.ones(108) * config.a) @@ -434,6 +378,7 @@ def __init__(self, config, device=torch.device('cuda') if torch.cuda.is_availabl self.num_targets = config.output_dim if config.output_dim != 0 else 1 self.compute_forces = config.compute_forces self.compute_stress = config.compute_stress + self.z_emb_ln = nn.LayerNorm(config.hidden_channels, 
elementwise_affine=False) self.z_emb = Embedding(95, config.hidden_channels) self.kernel1 = torch.nn.Parameter(torch.randn((config.hidden_channels, self.chi1 * 2), device=self.device)) @@ -455,11 +400,10 @@ def __init__(self, config, device=torch.device('cuda') if torch.cuda.is_availabl self.kernels_real = [] self.kernels_imag = [] self.zbl = config.zbl + if self.zbl: - M = 8 self.register_buffer('fzbl_w', torch.tensor([0.187,0.3769,0.189,0.081,0.003,0.037,0.0546,0.0715], dtype=torch.get_default_dtype())) self.register_buffer('fzbl_b', torch.tensor([3.20,1.10,0.102,0.958,1.28,1.14,1.69,5], dtype=torch.get_default_dtype())) - # normalize weights just in case with torch.no_grad(): w = getattr(self, 'fzbl_w') w = w.clamp(min=0.0) @@ -468,8 +412,6 @@ def __init__(self, config, device=torch.device('cuda') if torch.cuda.is_availabl self.register_buffer('fzbl_gamma', torch.tensor(1.001, dtype=torch.get_default_dtype())) self.register_buffer('fzbl_alpha', torch.tensor(0.6032, dtype=torch.get_default_dtype())) - - # physics constants self.register_buffer('fzbl_E2', torch.tensor(14.399645478425, dtype=torch.get_default_dtype())) # eV·Å self.register_buffer('fzbl_A0', torch.tensor(0.529177210903, dtype=torch.get_default_dtype())) # Å @@ -485,7 +427,7 @@ def __init__(self, config, device=torch.device('cuda') if torch.cuda.is_availabl has_norm_before_flag=config.has_norm_before_flag, has_norm_after_flag=config.has_norm_after_flag, hidden_channels_chi=config.hidden_channels_chi, - complex_type = self.complex_type, + complex_type=self.complex_type, device=device, reduce_mode=config.reduce_mode ) @@ -521,14 +463,13 @@ def reset_parameters(self): layer.reset_parameters() def forward(self, data: GraphData, prefix: str): - pos = data.pos batch = data.batch z = data.z.long() edge_index = data.edge_index - dist = data.edge_attr vecs = data.edge_vec + dist = torch.linalg.norm(vecs, dim=1) z_emb = self.z_emb_ln(self.z_emb(z)) radial_emb = self.radial_emb(dist) radial_hidden = self.radial_lin(radial_emb) @@ -542,18 +483,19 @@ def forward(self, data: GraphData, prefix: str): i = edge_index[1] edge_diff = vecs edge_diff = edge_diff / (dist.unsqueeze(1) + self.eps) - mean = scatter(pos[edge_index[0]], edge_index[1], reduce='mean', dim=0) - edge_cross = torch.cross(pos[i]-mean[i], pos[j]-mean[i]) + edge_vec_mean = scatter(vecs, i, reduce='mean', dim=0) + edge_cross = torch.cross(vecs, edge_vec_mean[i]) edge_vertical = torch.cross(edge_diff, edge_cross) edge_frame = torch.cat((edge_diff.unsqueeze(-1), edge_cross.unsqueeze(-1), edge_vertical.unsqueeze(-1)), dim=-1) S_i_j = self.S_vector(s, edge_diff.unsqueeze(-1), edge_index, radial_hidden) + scalrization1 = torch.sum(S_i_j[i].unsqueeze(2) * edge_frame.unsqueeze(-1), dim=1) scalrization2 = torch.sum(S_i_j[j].unsqueeze(2) * edge_frame.unsqueeze(-1), dim=1) - scalrization1[:, 1, :] = torch.abs(scalrization1[:, 1, :].clone()) - scalrization2[:, 1, :] = torch.abs(scalrization2[:, 1, :].clone()) - + scalrization1[:, 1, :] = torch.square(scalrization1[:, 1, :].clone()) + scalrization2[:, 1, :] = torch.square(scalrization2[:, 1, :].clone()) + scalar3 = (self.lin(torch.permute(scalrization1, (0, 2, 1))) + torch.permute(scalrization1, (0, 2, 1))[:, :, 0].unsqueeze(2)).squeeze(-1) / math.sqrt(self.hidden_channels) scalar4 = (self.lin(torch.permute(scalrization2, (0, 2, 1))) + @@ -582,73 +524,27 @@ def forward(self, data: GraphData, prefix: str): equation = 'ikl,bi,bl->bk' kerneli = torch.complex(kernel_real, kernel_imag) quantum = torch.einsum(equation, kerneli, 
s.to(self.complex_type), quantum) - quantum = quantum / quantum.abs().to(self.complex_type) + quantum = quantum / (self.eps + quantum.abs().to(self.complex_type)) ds, dvec = fte(s, vec) s = s + ds vec = vec + dvec s = self.last_layer(s) + self.last_layer_quantum(torch.cat([quantum.real, quantum.imag], dim=-1)) / self.chi1 - V_graph = 0 - if self.zbl: - r_e = dist - Z_j = z[j] - Z_i = z[i] - - # load constants (buffers) and cast to pos dtype/device - w = self.fzbl_w.to(device=s.device, dtype=s.dtype) # (M,) - b = self.fzbl_b.to(device=s.device, dtype=s.dtype) # (M,) - gamma = self.fzbl_gamma.to(device=s.device, dtype=s.dtype) - alpha = self.fzbl_alpha.to(device=s.device, dtype=s.dtype) - E2 = self.fzbl_E2.to(device=s.device, dtype=s.dtype) - A0 = self.fzbl_A0.to(device=s.device, dtype=s.dtype) - - # compute screening length a per edge: a = gamma * 0.8854 * a0 / (Z1^alpha + Z2^alpha) - denom = torch.pow(Z_j, alpha) + torch.pow(Z_i, alpha) # (E,) - denom = torch.clamp(denom, min=1e-12) - a_vals = gamma * 0.8854 * A0 / denom # (E,) - x = r_e / a_vals # (E,) - - # compute phi(x) = sum_i w_i * exp(-b_i * x) (vectorized) - # exp(- x[:,None] * b[None,:]) -> (E, M) - exp_terms = torch.exp(- x.unsqueeze(1) * b.unsqueeze(0)) # (E, M) - phi_vals = exp_terms.matmul(w) # (E,) - - # pair potential per edge: V_e = Z1*Z2 * E2 * phi / r - V_edge = (Z_j * Z_i * E2) * (phi_vals / r_e) # (E,) - r_cut = 1.0 # you can make this self.fzbl_rcut buffer if you want configurable value - - # compute taper coefficient: cosine cutoff (smooth) - # for r in [0, r_cut]: c = 0.5*(cos(pi * r / r_cut) + 1) - # for r >= r_cut: c = 0 - # for safety, clamp r/r_cut in [0, 1] - xrc = (r_e / r_cut).clamp(min=0.0, max=1.0) # (E,) - # cosine taper - c = 0.5 * (torch.cos(torch.pi * xrc) + 1.0) # (E,) - # enforce zero beyond r_cut explicitly (cos already gives 0 at x=1 but clamp keeps numeric safe) - c = torch.where(r_e >= r_cut, torch.zeros_like(c), c) - - # apply taper to edge potential - V_edge = V_edge * c - # aggregate edge energies to graph-level (use receiver node's batch index) - # use torch_scatter.scatter_add (or your existing scatter) to sum per-graph - - graph_idx = batch[i] # map receiver node -> graph index (E,) - V_graph = scatter_add(V_edge, graph_idx, dim=0) / 2.0 + if s.dim() == 2: s = (self.a[z].unsqueeze(1) * s + self.b[z].unsqueeze(1)) elif s.dim() == 1: s = (self.a[z] * s + self.b[z]).unsqueeze(1) else: raise ValueError(f"Unexpected shape of s: {s.shape}") - #print(s.shape, V_graph.shape, batch.shape) - s = scatter(s, batch, dim=0, reduce=self.readout).squeeze()#+ V_graph - #print(s.shape) + + s = scatter(s, batch, dim=0, reduce=self.readout).squeeze() + if self.use_sigmoid: s = torch.sigmoid((s - 0.5) * 5) - #return s, None, None - if self.compute_forces and self.compute_stress: + if self.compute_forces and self.compute_stress: if data.displacement is not None: stress, forces = self.cal_stress_and_force(s, pos, data.displacement, data.cell, prefix) stress = stress.view(-1, 3) @@ -659,10 +555,10 @@ def forward(self, data: GraphData, prefix: str): elif self.compute_forces: forces = self.cal_forces(s, pos, prefix) return s, forces, None + return s, None, None def cal_forces(self, energy, positions, prefix: str = 'infer'): - graph = (prefix == "train") grad_outputs = torch.jit.annotate(List[Optional[torch.Tensor]], [torch.ones_like(energy)]) forces = torch.autograd.grad( @@ -676,9 +572,9 @@ def cal_forces(self, energy, positions, prefix: str = 'infer'): assert forces is not None, "Gradient should not be None" return 
-forces

-    def cal_stress_and_force(self, energy: Tensor,positions: Tensor, displacement: Optional[Tensor], cell: Tensor, prefix: str) -> Tuple[Tensor, Tensor]:
+    def cal_stress_and_force(self, energy: Tensor, positions: Tensor, displacement: Optional[Tensor], cell: Tensor, prefix: str) -> Tuple[Tensor, Tensor]:
         if displacement is None:
-             raise ValueError("displacement cannot be None for stress calculation")
+            raise ValueError("displacement cannot be None for stress calculation")
         graph = (prefix == "train")
         grad_outputs = torch.jit.annotate(List[Optional[torch.Tensor]], [torch.ones_like(energy)])
         output = torch.autograd.grad(
@@ -694,10 +590,7 @@ def cal_stress_and_force(self, energy: Tensor,positions: Tensor, displacement: O
         volume = torch.abs(torch.linalg.det(cell))
         volume_expanded = volume.reshape(-1, 1, 1)
         stress = virial / volume_expanded
-        force =output[1]
+        force = output[1]
         assert force is not None, "Forces tensor should not be None"
-        return stress, -force
-
-
-
+        return stress, -force
\ No newline at end of file
diff --git a/alphanet/models/graph.py b/alphanet/models/graph.py
index 8e18476..69cfcfa 100644
--- a/alphanet/models/graph.py
+++ b/alphanet/models/graph.py
@@ -468,7 +468,7 @@ def _process_positions_and_edges(
     if use_pbc and cell is not None:
         edge_index, cell_offsets, neighbors = radius_graph_pbc(
-            pos, natoms, cell, cutoff, max_num_neighbors_threshold=50, precision=precision
+            pos, natoms, cell, cutoff, max_num_neighbors_threshold=500, precision=precision
         )
         new_pos = pos
         new_z = z
diff --git a/alphanet/models/model.py b/alphanet/models/model.py
index 068e79f..325b8a6 100644
--- a/alphanet/models/model.py
+++ b/alphanet/models/model.py
@@ -24,7 +24,6 @@ def forward(self,
                 natoms: Tensor,
                 cell: Optional[Tensor] = None,
                 prefix: str = 'infer'):
-
         processed_data = process_positions_and_edges(
             pos=pos,
             z=z,
diff --git a/alphanet/models/zbl.py b/alphanet/models/zbl.py
new file mode 100644
index 0000000..4304bd2
--- /dev/null
+++ b/alphanet/models/zbl.py
@@ -0,0 +1,76 @@
+# File path: AlphaNet-lammps/alphanet/models/zbl.py
+import torch
+from torch import nn
+import math
+
+class ZBLPotential(nn.Module):
+    def __init__(self, config):
+        super().__init__()
+        # Default parameters (generic fitted values)
+        default_w = [0.187, 0.3769, 0.189, 0.081, 0.003, 0.037, 0.0546, 0.0715]
+        default_b = [3.20, 1.10, 0.102, 0.958, 1.28, 1.14, 1.69, 5.0]
+
+        # Read from the config, so the ZBL parameters can be fine-tuned
+        w = getattr(config, 'zbl_w', default_w)
+        if w is None: w = default_w
+
+        b = getattr(config, 'zbl_b', default_b)
+        if b is None: b = default_b
+
+        gamma = getattr(config, 'zbl_gamma', 1.001)
+        alpha = getattr(config, 'zbl_alpha', 0.6032)
+
+        self.register_buffer('fzbl_w', torch.tensor(w, dtype=torch.get_default_dtype()))
+        self.register_buffer('fzbl_b', torch.tensor(b, dtype=torch.get_default_dtype()))
+
+        # Normalize the weights
+        with torch.no_grad():
+            w_tensor = self.fzbl_w
+            w_tensor = w_tensor.clamp(min=0.0)
+            w_tensor = w_tensor / (w_tensor.sum() + 1e-12)
+            self.fzbl_w.copy_(w_tensor)
+
+        self.register_buffer('fzbl_gamma', torch.tensor(gamma, dtype=torch.get_default_dtype()))
+        self.register_buffer('fzbl_alpha', torch.tensor(alpha, dtype=torch.get_default_dtype()))
+
+        # Physical constants
+        self.register_buffer('fzbl_E2', torch.tensor(14.399645478425, dtype=torch.get_default_dtype()))  # eV·Å
+        self.register_buffer('fzbl_A0', torch.tensor(0.529177210903, dtype=torch.get_default_dtype()))  # Å
+
+        # Smooth cutoff parameter
+        self.r_cut = 1.0
+
+    def forward(self, dist: torch.Tensor, z_i: torch.Tensor, z_j: torch.Tensor) -> torch.Tensor:
+        """
+        Pair ZBL Energy V_ij
+        """
+        r_e = dist
+
+        # Keep the computation dtype consistent
+        dtype = r_e.dtype
+        w = self.fzbl_w.to(dtype=dtype)
+        b = self.fzbl_b.to(dtype=dtype)
+        gamma = self.fzbl_gamma.to(dtype=dtype)
+        alpha = self.fzbl_alpha.to(dtype=dtype)
+        E2 = self.fzbl_E2.to(dtype=dtype)
+        A0 = self.fzbl_A0.to(dtype=dtype)
+
+        denom = torch.pow(z_j, alpha) + torch.pow(z_i, alpha)
+        denom = torch.clamp(denom, min=1e-12)
+        a_vals = gamma * 0.8854 * A0 / denom
+        x = r_e / a_vals
+
+        # V(r) = (Z1*Z2*e^2/r) * phi(x)
+        # phi(x) = sum(w * exp(-b*x))
+        exp_terms = torch.exp(-x.unsqueeze(1) * b.unsqueeze(0))  # [Edges, M]
+        phi_vals = exp_terms.matmul(w)  # [Edges]
+
+        V_edge = (z_j * z_i * E2) * (phi_vals / r_e)
+
+        # Cutoff smoothing (0 to r_cut)
+
+        xrc = (r_e / self.r_cut).clamp(min=0.0, max=1.0)
+        c = 0.5 * (torch.cos(math.pi * xrc) + 1.0)
+        c = torch.where(r_e >= self.r_cut, torch.zeros_like(c), c)
+
+        return V_edge * c
\ No newline at end of file
diff --git a/alphanet/models/zbl_jax.py b/alphanet/models/zbl_jax.py
new file mode 100644
index 0000000..041d053
--- /dev/null
+++ b/alphanet/models/zbl_jax.py
@@ -0,0 +1,48 @@
+import jax.numpy as jnp
+import jax
+
+def zbl_interaction(dist, z_i, z_j, w, b, gamma, alpha, E2, A0, r_cut=1.0):
+    """
+    Compute the ZBL pair potential (JAX functional implementation)
+    """
+    r_e = dist
+
+    # Guard against division by zero
+    denom = jnp.power(z_j, alpha) + jnp.power(z_i, alpha)
+    denom = jnp.clip(denom, a_min=1e-12)
+
+    a_vals = gamma * 0.8854 * A0 / denom
+    x = r_e / a_vals
+
+    # phi(x) = sum(w * exp(-b*x))
+    # w: (M,), b: (M,), x: (Edges,)
+    # exp(- x[:, None] * b[None, :]) -> (Edges, M)
+    exp_terms = jnp.exp(-x[:, jnp.newaxis] * b[jnp.newaxis, :])
+    phi_vals = jnp.dot(exp_terms, w)
+
+    V_edge = (z_j * z_i * E2) * (phi_vals / r_e)
+
+    # Cutoff smoothing
+    xrc = jnp.clip(r_e / r_cut, 0.0, 1.0)
+    c = 0.5 * (jnp.cos(jnp.pi * xrc) + 1.0)
+    c = jnp.where(r_e >= r_cut, jnp.zeros_like(c), c)
+
+    return V_edge * c
+
+def get_default_zbl_params(dtype=jnp.float32):
+    fzbl_w = jnp.array([0.187, 0.3769, 0.189, 0.081, 0.003, 0.037, 0.0546, 0.0715], dtype=dtype)
+    # Normalize
+    fzbl_w = jnp.clip(fzbl_w, a_min=0.0)
+    fzbl_w = fzbl_w / (jnp.sum(fzbl_w) + 1e-12)
+
+    fzbl_b = jnp.array([3.20, 1.10, 0.102, 0.958, 1.28, 1.14, 1.69, 5.0], dtype=dtype)
+
+    params = {
+        "w": fzbl_w,
+        "b": fzbl_b,
+        "gamma": jnp.array(1.001, dtype=dtype),
+        "alpha": jnp.array(0.6032, dtype=dtype),
+        "E2": jnp.array(14.399645478425, dtype=dtype),
+        "A0": jnp.array(0.529177210903, dtype=dtype)
+    }
+    return params
\ No newline at end of file
diff --git a/alphanet/mul_trainer.py b/alphanet/mul_trainer.py
index a3085f1..6dc6627 100644
--- a/alphanet/mul_trainer.py
+++ b/alphanet/mul_trainer.py
@@ -1,5 +1,5 @@
 import torch
-from torch.optim import Adam, AdamW
+from torch.optim import Adam, AdamW, RAdam
 from torch_geometric.data import DataLoader
 from torch.optim.lr_scheduler import StepLR, ReduceLROnPlateau, CosineAnnealingLR
 import pytorch_lightning as pl
@@ -177,7 +177,19 @@ def test_step(self, batch, batch_idx):
         return loss

     def configure_optimizers(self):
-        optimizer = Adam(self.parameters(), lr=self.config.train.lr, weight_decay=self.config.train.weight_decay)
+
+        opt_name = self.config.train.optimizer.lower()
+        lr = self.config.train.lr
+        weight_decay = self.config.train.weight_decay
+
+        if opt_name == 'adam':
+            optimizer = Adam(self.parameters(), lr=lr, weight_decay=weight_decay)
+        elif opt_name == 'adamw':
+            optimizer = AdamW(self.parameters(), lr=lr, weight_decay=weight_decay)
+        elif opt_name == 'radam':
+            optimizer = RAdam(self.parameters(), lr=lr, weight_decay=weight_decay)
+        else:
+            raise
ValueError(f"Unknown optimizer: {opt_name}") if self.config.train.scheduler == 'steplr': scheduler = StepLR(optimizer, step_size=self.config.train.lr_decay_step_size, gamma=self.config.train.lr_decay_factor) diff --git a/alphanet/train.py b/alphanet/train.py index ea32865..06fd768 100644 --- a/alphanet/train.py +++ b/alphanet/train.py @@ -5,45 +5,89 @@ from alphanet.models.model import AlphaNetWrapper from alphanet.mul_trainer import Trainer -def run_training(config1,config2): - - train_dataset, valid_dataset, test_dataset = get_pic_datasets(root='dataset/', name=config1.dataset_name,config = config1) +def run_training(config1, runtime_config): + + train_dataset, valid_dataset, test_dataset = get_pic_datasets( + root='dataset/', + name=config1.dataset_name, + config=config1 + ) + force_std = torch.std(train_dataset.data.force).item() - - energy_peratom = torch.sum(train_dataset.data.y).item()/torch.sum(train_dataset.data.natoms).item() + if hasattr(train_dataset.data, 'y') and train_dataset.data.y is not None: + energy_peratom = torch.sum(train_dataset.data.y).item() / torch.sum(train_dataset.data.natoms).item() + else: + energy_peratom = 0.0 + config1.a = force_std config1.b = energy_peratom - #print(config1.a, config1.b) + model = AlphaNetWrapper(config1) - #print(model.model.a, model.model.b) + if config1.dtype == "64": - model = model.double() - #strategy = DDPStrategy(num_nodes=config["hardware"]["num_nodes"]) if config["hardware"]["num_nodes"] > 1 else "auto" + model = model.double() + + + if runtime_config.get("finetune_path"): + ft_path = runtime_config["finetune_path"] + print(f"🔨 Finetuning mode: Loading weights from {ft_path}...") + + + try: + ckpt = torch.load(ft_path, map_location='cpu') + except FileNotFoundError: + raise FileNotFoundError(f"Finetune checkpoint not found at: {ft_path}") + + + if isinstance(ckpt, dict) and 'state_dict' in ckpt: + state_dict = ckpt['state_dict'] + else: + state_dict = ckpt + new_state_dict = state_dict + + missing, unexpected = model.load_state_dict(new_state_dict, strict=False) + + if len(missing) > 0: + print(f" Warning: Missing keys ({len(missing)}): {missing[:3]} ...") + if len(unexpected) > 0: + print(f" Warning: Unexpected keys ({len(unexpected)}): {unexpected[:3]} ...") + + print(" ✅ Weights loaded successfully. 
Optimizer states reset for finetuning.") + else: + print("🆕 Training from scratch (random initialization).") + + checkpoint_callback = ModelCheckpoint( dirpath=config1.train.save_dir, filename='{epoch}-{val_loss:.4f}-{val_energy_loss:.4f}-{val_force_loss:.4f}', save_top_k=-1, - every_n_epochs=1, - save_on_train_epoch_end=True, + every_n_epochs=1, + save_on_train_epoch_end=True, monitor='val_loss', mode='min' ) + trainer = pl.Trainer( - devices=config2["num_devices"], - num_nodes=config2["num_nodes"], - strategy='ddp_find_unused_parameters_true', - accelerator="gpu" if config2["num_devices"] > 0 else "cpu", + devices=runtime_config["num_devices"], + num_nodes=runtime_config["num_nodes"], + strategy='ddp_find_unused_parameters_true', + accelerator="gpu" if runtime_config["num_devices"] > 0 and torch.cuda.is_available() else "cpu", max_epochs=config1.epochs, callbacks=[checkpoint_callback], enable_checkpointing=True, - gradient_clip_val=0.5, + gradient_clip_val=0.1, default_root_dir=config1.train.save_dir, accumulate_grad_batches=config1.accumulation_steps, - limit_val_batches=100, ) + + pl_module = Trainer(config1, model, train_dataset, valid_dataset, test_dataset) - model = Trainer(config1, model, train_dataset, valid_dataset, test_dataset) - trainer.fit(model, ckpt_path=config2["ckpt_path"] if config2["resume"] else None) + ckpt_path_arg = runtime_config["ckpt_path"] if runtime_config["resume"] else None + + if runtime_config["resume"] and ckpt_path_arg: + print(f"🔄 Resuming training from checkpoint: {ckpt_path_arg}") + + trainer.fit(pl_module, ckpt_path=ckpt_path_arg) \ No newline at end of file diff --git a/example.json b/example.json deleted file mode 100644 index 5c0511c..0000000 --- a/example.json +++ /dev/null @@ -1,47 +0,0 @@ -{ - "data": { - "root": "dataset/", - "dataset_name": "t1", - "target": "energy_force", - "train_dataset": "t1", - "train_size": 2000, - "valid_dataset": "t1", - "valid_size": 2000, - "test_dataset": "t1", - "seed": 42 - }, - "model": { - "name": "Alphanet", - "num_layers": 3, - "num_targets": 1, - "output_dim": 1, - "compute_forces": true, - "use_pbc": true, - "has_dropout_flag": false, - "hidden_channels": 256, - "cutoff": 4, - "num_radial": 8 - }, - "train": { - "epochs": 200, - "batch_size": 24, - "accumulation_steps": 1, - "vt_batch_size": 24, - "lr": 0.0001, - "lr_decay_factor": 0.8, - "lr_decay_step_size": 80, - "weight_decay": 0, - "save_dir": "t1", - "log_dir": "", - "disable_tqdm": false, - "scheduler": "cosineannealinglr", - "device": "cuda", - "energy_loss": "mae", - "force_loss": "mae", - "energy_metric": "mae", - "force_metric": "mae", - "energy_coef": 4.0, - "force_coef": 100.0, - "eval_steps": 1 - } -} diff --git a/example_for_help.json b/example_for_help.json deleted file mode 100644 index 43022be..0000000 --- a/example_for_help.json +++ /dev/null @@ -1,48 +0,0 @@ -{ - "data": { - "root": "dataset/", // Root directory of the dataset - "dataset_name": "t2", // Dataset name (Just a name) - "target": "energy_force", // Task: predict energy and forces - "train_dataset": "t2", // Training dataset name (Exactly the name of dataset folder) - "train_size": 2000, // Number of training samples - "valid_dataset": "t2_valid", // Validation dataset name (Exactly the name of dataset folder) - "valid_size": 2000, // Number of validation samples - "test_dataset": "t2_valid", // Test dataset name (Exactly the name of dataset folder) - "seed": 42 // Random seed for reproducibility - }, - "model": { - "name": "Alphanet", // Model name - "num_layers": 3, // 
Number of layers - "num_targets": 1, // Number of target variables - "output_dim": 1, // Output dimension - "compute_forces": true, // Predict forces - "compute_stress": true, // Predict stress - "use_pbc": true, // Use periodic boundary conditions - "hidden_channels": 128, // Hidden layer size - "cutoff": 4, // Cutoff radius for interactions - "num_radial": 8 // Number of radial basis functions - }, - "train": { - "epochs": 200, // Total training epochs - "batch_size": 4, // Batch size - "accumulation_steps": 1, // Gradient accumulation steps - "vt_batch_size": 4, // Batch size for validation/testing - "lr": 0.0001, // Initial learning rate - "lr_decay_factor": 0.8, // Learning rate decay factor - "lr_decay_step_size": 80, // Step size for learning rate decay - "weight_decay": 0, // L2 regularization weight - "save_dir": "t2_15", // Model save directory - "log_dir": "", // Log directory (empty if none) - "disable_tqdm": false, // Disable progress bar - "scheduler": "cosineannealinglr",// Learning rate scheduler - "device": "cuda", // Device for training (GPU) - "energy_loss": "mae", // Loss function for energy (MAE) - "force_loss": "mae", // Loss function for forces (MAE) - "energy_metric": "mae", // Metric for energy evaluation - "force_metric": "mae", // Metric for force evaluation - "energy_coef": 4.0, // Coefficient for energy loss - "force_coef": 100.0, // Coefficient for force loss - "eval_steps": 1 // Evaluation step interval - } -} - diff --git a/lmp_requirements.txt b/lmp_requirements.txt new file mode 100644 index 0000000..3eb9aa4 --- /dev/null +++ b/lmp_requirements.txt @@ -0,0 +1,6 @@ +cython + +pytest +attrs +jinja2 +urllib3 diff --git a/mliap_lammps.md b/mliap_lammps.md new file mode 100644 index 0000000..e857bf2 --- /dev/null +++ b/mliap_lammps.md @@ -0,0 +1,159 @@ +# LAMMPS Installation Guide: ML-IAP with KOKKOS (CUDA) Support + +This guide details the process for building LAMMPS with the Machine Learning Interatomic Potential (`ML-IAP`) package, enabled with Python bindings and KOKKOS acceleration for NVIDIA GPUs. + +## Prerequisites +* **Git** & **CMake** +* **C++ Compiler** (compatible with your MPI/CUDA version) +* **MPI Implementation** (e.g., OpenMPI, MPICH) +* **CUDA Toolkit** (version 11.x or 12.x) +* **Python 3.x** + +--- + +## 1. Get the Source Code +Clone the repository and check out the specific commit hash used for this build to ensure reproducibility. + +```bash +git clone https://github.com/lammps/lammps.git +cd lammps + +# Checkout specific commit for stability/reproducibility +git checkout ccca772 + +``` + +## 2. Prepare the Build Environment + +We will use an out-of-source build directory to keep the source tree clean. We will also start with the `kokkos-cuda` CMake preset. + +```bash +mkdir build-mliap +cd build-mliap + +# Copy the KOKKOS CUDA preset to the build directory +cp ../cmake/presets/kokkos-cuda.cmake ./ + +``` + +### ⚠️ Important Configuration Step + +Before generating the build files, you must edit `kokkos-cuda.cmake` to match your specific GPU architecture. + +1. Open `kokkos-cuda.cmake` in a text editor. +2. Locate the architecture flag (e.g., `-DKokkos_ARCH_...`). +3. Change it to match your GPU (e.g., `Kokkos_ARCH_VOLTA70`, `Kokkos_ARCH_AMPERE80`, `Kokkos_ARCH_HOPPER90`). +* *Reference:* [LAMMPS KOKKOS Build Options](https://docs.lammps.org/Build_extras.html#kokkos) + + + +--- + +## 3. Configure and Compile + +Run CMake to configure the build with ML-IAP, SNAP, and Python support enabled. 
+
+```bash
+cmake -C kokkos-cuda.cmake \
+      -D CMAKE_BUILD_TYPE=Release \
+      -D CMAKE_INSTALL_PREFIX=$(pwd) \
+      -D BUILD_MPI=ON \
+      -D PKG_ML-IAP=ON \
+      -D PKG_ML-SNAP=ON \
+      -D MLIAP_ENABLE_PYTHON=ON \
+      -D PKG_PYTHON=ON \
+      -D BUILD_SHARED_LIBS=ON \
+      ../cmake
+```
+
+**Key Flags Explained:**
+
+* `PKG_ML-IAP=ON`: Enables the Machine Learning Interatomic Potentials package.
+* `MLIAP_ENABLE_PYTHON=ON`: Allows ML-IAP to call Python functions (essential for PyTorch/PyG models).
+* `BUILD_SHARED_LIBS=ON`: Builds LAMMPS as a shared library (`.so`), required for the Python module.
+
+### Compile
+
+Compile the code using multiple cores (adjust `-j 8` based on your CPU core count).
+
+```bash
+make -j 8
+```
+
+---
+
+## 4. Python Environment Setup
+
+Install the LAMMPS Python interface and the necessary dependencies.
+
+```bash
+# Install the lammps python module into your current environment
+make install-python
+
+# Install dependencies for your ML model
+cd ../../
+pip install -r lmp_requirements.txt
+
+# Install CuPy (ensure the version matches your CUDA version)
+# For CUDA 12.x:
+pip install cupy-cuda12x
+# For CUDA 11.x, use: pip install cupy-cuda11x
+```
+
+---
+
+## 5. Running LAMMPS
+
+Below are the commands to run LAMMPS using the KOKKOS accelerator package on GPUs.
+
+### Convert the checkpoint:
+
+```bash
+python alphanet/create_lammps_model.py \
+    --config ./pretrained/OMA/oma.json \
+    --checkpoint ./pretrained/OMA/alex_1212.ckpt \
+    --output ./alphanet_lammps.pt \
+    --dtype float64 \
+    --device cpu
+```
+
+A quick way to sanity-check the converted `.pt` file is shown at the end of this guide.
+
+### Input file:
+
+Necessary settings:
+
+```bash
+units metal
+atom_style atomic
+newton on
+pair_style mliap unified your_converted.pt 0
+```
+
+### Single GPU Execution
+
+Run on 1 GPU without MPI.
+
+```bash
+lmp -k on g 1 -sf kk -pk kokkos newton on neigh half gpu/aware on -in test.in
+```
+
+### Multi-GPU Execution
+
+Run on 2 GPUs using MPI.
+
+```bash
+mpirun -np 2 lmp -k on g 2 -sf kk -pk kokkos newton on neigh half gpu/aware on -in sl.in
+```
+
+### Runtime Flags Breakdown
+
+* `-k on g X`: Enable KOKKOS and use **X** GPUs per node.
+* `-sf kk`: **Suffix KOKKOS**. Automatically appends `/kk` to styles in the input script (e.g., `pair_style` becomes `pair_style/kk`).
+* `-pk kokkos`: Modifies global KOKKOS parameters:
+  * `newton on`: Turns on Newton's 3rd law optimizations (often faster for GPUs).
+  * `neigh half`: Uses a half-neighbor list (often more efficient on GPUs).
+  * `gpu/aware on`: Optimizes MPI communication if using CUDA-aware MPI.
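+
+### Optional: Sanity-Check the Converted Model
+
+Before wiring the converted checkpoint into LAMMPS, it can be worth confirming that it deserializes in plain Python. The sketch below is illustrative: it assumes the file produced by `create_lammps_model.py` is a pickled Python wrapper object exposing a `.model` attribute (the same assumption `run_lammps.py` in this repository makes) and that the path matches your `--output`.
+
+```python
+import torch
+
+# The converted file is a pickled wrapper object, not a bare state_dict.
+# NOTE: unpickling it requires the alphanet package to be importable.
+obj = torch.load("./alphanet_lammps.pt", map_location="cpu")
+print(type(obj))
+
+if hasattr(obj, "model"):
+    obj.model.eval()
+    n_params = sum(p.numel() for p in obj.model.parameters())
+    print(f"wrapped model ready: {n_params} parameters")
+```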
+ + + diff --git a/mul_train.py b/mul_train.py deleted file mode 100644 index 8457a23..0000000 --- a/mul_train.py +++ /dev/null @@ -1,68 +0,0 @@ -import torch -import pytorch_lightning as pl -from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping -from alphanet.data import get_pic_datasets -from alphanet.models.model import AlphaNetWrapper -from alphanet.config import All_Config -from alphanet.mul_trainer import Trainer -import os -def main(): - config = All_Config().from_json("OC2M-train.json") - train_dataset, valid_dataset, test_dataset = train_dataset, valid_dataset, test_dataset = get_pic_datasets(root='dataset/', name=config.dataset_name,config = config) - force_std = torch.std(train_dataset.data.force).item() - ENERGY_MEAN_TOTAL = 0 - FORCE_MEAN_TOTAL = 0 - NUM_ATOM = None - - for data in valid_dataset: - energy = data.y - force = data.force - NUM_ATOM = force.size()[0] - energy_mean = energy / NUM_ATOM - ENERGY_MEAN_TOTAL += energy_mean - - ENERGY_MEAN_TOTAL /= len(train_dataset) - - config.a = force_std - config.b = ENERGY_MEAN_TOTAL - - model = AlphaNetWrapper(config) - - checkpoint_callback = ModelCheckpoint( - dirpath=config.train.save_dir, - filename='{epoch}-{val_loss:.4f}-{val_energy_mae:.4f}-{val_force_mae:.4f}', - save_top_k=-1, - every_n_epochs=1, - save_on_train_epoch_end=True, - monitor='val_loss', - mode='min' - ) - - early_stopping_callback = EarlyStopping( - monitor='val_loss', - patience= 50, - mode='min' - ) - - trainer = pl.Trainer( - devices=3, - num_nodes=1, - limit_train_batches=40000, - accelerator='auto', - #inference_mode=False, - - strategy='ddp_find_unused_parameters_true', - max_epochs=config.train.epochs, - callbacks=[checkpoint_callback, early_stopping_callback], - default_root_dir=config.train.save_dir, - logger=pl.loggers.TensorBoardLogger(config.train.log_dir), - gradient_clip_val=0.5, - accumulate_grad_batches=config.train.accumulation_steps - ) - - model = Trainer(config, model, train_dataset, valid_dataset, test_dataset) - trainer.fit(model)#, ckpt_path = ckpt) - trainer.test() - -if __name__ == '__main__': - main() diff --git a/pretrained/AQCAT25/README.md b/pretrained/AQCAT25/README.md deleted file mode 100644 index 3a7364e..0000000 --- a/pretrained/AQCAT25/README.md +++ /dev/null @@ -1,23 +0,0 @@ -# AlphaNet-AQCAT25 - -A model trained on the [AQCAT25](https://www.sandboxaq.com/aqcat25), mainly used for surfaces adsorbtion and reactions. The model is trained on the total energies and forces of trainiing set and slabs data. 
- -## Model Details - -* **Parameters:** Approximately 6.9M - -## Access the Model - -The following resources are available in the `pretrained_models/AQCAT25` path: -* **Model Configuration:** `aqcat.json` -* **Model state\_dict:** Pre-trained weights `aqcat_1021.ckpt` - - -## Performance -| Mae | Value | Unit/Description | -| :--- | :--- | :--- | -| test_id | 0.010,0.088 | eV/atom , eV/$ \AA $| -| test_ood_ads | 0.010,0.082 | eV/atom , eV/$ \AA $ | -| test_ood_both | 0.024, 0.097 | eV/atom , eV/$ \AA $ | -| test_ood_mat | 0.0186, 0.101 | eV/atom , eV/$ \AA $ | -| test_ood_slabs | 0.025, 0.091 | eV/atom , eV/$ \AA $ | diff --git a/pretrained/AQCAT25/aqcat.json b/pretrained/AQCAT25/aqcat.json deleted file mode 100644 index 80e876d..0000000 --- a/pretrained/AQCAT25/aqcat.json +++ /dev/null @@ -1,51 +0,0 @@ -{ - "data": { - "root": "dataset/", - "dataset_name": "t1", - "target": "energy_force", - "train_dataset": "aqcat", - "valid_dataset": "test_ood_ads", - "test_dataset": "test_ood_mat", - "seed": 42 - }, - "model": { - "name": "Alphanet", - "num_layers": 4, - "num_targets": 1, - "output_dim": 1, - "compute_forces": true, - "compute_stress": false, - "main_chi1": 64, - "mp_chi1": 48, - "chi2": 32, - "use_pbc": true, - "has_dropout_flag": false, - "hidden_channels": 128, - "cutoff": 6, - "num_radial": 8 - }, - "train": { - "epochs": 500, - "batch_size":32, - "accumulation_steps": 1, - "vt_batch_size": 32, - "lr": 0.0002, - "lr_decay_factor": 0.9, - "lr_decay_step_size": 50000, - "weight_decay": 0.0001, - "save_dir": "./dac", - "log_dir": "", - "disable_tqdm": false, - "scheduler": "steplr", - "device": "cuda", - "energy_loss": "mae", - "force_loss": "mae", - "stress_loss": "mae", - "energy_metric": "mae", - "force_metric": "mae", - "energy_coef": 4.0, - "stress_coef": 0.5, - "force_coef": 100.0, - "eval_steps": 1 - } -} diff --git a/pretrained/AQCAT25/aqcat_1021.ckpt b/pretrained/AQCAT25/aqcat_1021.ckpt deleted file mode 100644 index d79ac9a..0000000 Binary files a/pretrained/AQCAT25/aqcat_1021.ckpt and /dev/null differ diff --git a/pretrained/AQCAT25/haiku_model_converted/conversion_map.txt b/pretrained/AQCAT25/haiku_model_converted/conversion_map.txt deleted file mode 100644 index 1371d3a..0000000 --- a/pretrained/AQCAT25/haiku_model_converted/conversion_map.txt +++ /dev/null @@ -1,160 +0,0 @@ -Flax Key -> Haiku Key Mapping --------------------------------------------------------------------------------- -params/a -> alpha_net_hiku/a -params/b -> alpha_net_hiku/b -params/ftes_0/vec_proj/kernel -> alpha_net_hiku/~/fte/~/linear/w -params/ftes_0/xvec_proj/layers_0/bias -> alpha_net_hiku/~/fte/~/linear_1/b -params/ftes_0/xvec_proj/layers_0/kernel -> alpha_net_hiku/~/fte/~/linear_1/w -params/ftes_0/xvec_proj/layers_2/bias -> alpha_net_hiku/~/fte/~/linear_2/b -params/ftes_0/xvec_proj/layers_2/kernel -> alpha_net_hiku/~/fte/~/linear_2/w -params/ftes_1/vec_proj/kernel -> alpha_net_hiku/~/fte_1/~/linear/w -params/ftes_1/xvec_proj/layers_0/bias -> alpha_net_hiku/~/fte_1/~/linear_1/b -params/ftes_1/xvec_proj/layers_0/kernel -> alpha_net_hiku/~/fte_1/~/linear_1/w -params/ftes_1/xvec_proj/layers_2/bias -> alpha_net_hiku/~/fte_1/~/linear_2/b -params/ftes_1/xvec_proj/layers_2/kernel -> alpha_net_hiku/~/fte_1/~/linear_2/w -params/ftes_2/vec_proj/kernel -> alpha_net_hiku/~/fte_2/~/linear/w -params/ftes_2/xvec_proj/layers_0/bias -> alpha_net_hiku/~/fte_2/~/linear_1/b -params/ftes_2/xvec_proj/layers_0/kernel -> alpha_net_hiku/~/fte_2/~/linear_1/w -params/ftes_2/xvec_proj/layers_2/bias -> 
alpha_net_hiku/~/fte_2/~/linear_2/b -params/ftes_2/xvec_proj/layers_2/kernel -> alpha_net_hiku/~/fte_2/~/linear_2/w -params/ftes_3/vec_proj/kernel -> alpha_net_hiku/~/fte_3/~/linear/w -params/ftes_3/xvec_proj/layers_0/bias -> alpha_net_hiku/~/fte_3/~/linear_1/b -params/ftes_3/xvec_proj/layers_0/kernel -> alpha_net_hiku/~/fte_3/~/linear_1/w -params/ftes_3/xvec_proj/layers_2/bias -> alpha_net_hiku/~/fte_3/~/linear_2/b -params/ftes_3/xvec_proj/layers_2/kernel -> alpha_net_hiku/~/fte_3/~/linear_2/w -params/kernel1 -> alpha_net_hiku/kernel1 -params/kernels_imag -> alpha_net_hiku/kernels_imag -params/kernels_real -> alpha_net_hiku/kernels_real -params/last_layer/bias -> alpha_net_hiku/~/linear_4/b -params/last_layer/kernel -> alpha_net_hiku/~/linear_4/w -params/last_layer_quantum/bias -> alpha_net_hiku/~/linear_5/b -params/last_layer_quantum/kernel -> alpha_net_hiku/~/linear_5/w -params/lin/layers_0/bias -> alpha_net_hiku/~/linear_2/b -params/lin/layers_0/kernel -> alpha_net_hiku/~/linear_2/w -params/lin/layers_2/bias -> alpha_net_hiku/~/linear_3/b -params/lin/layers_2/kernel -> alpha_net_hiku/~/linear_3/w -params/message_layers_0/dia/bias -> alpha_net_hiku/~/equi_message_passing/~message/linear_1/b -params/message_layers_0/dia/kernel -> alpha_net_hiku/~/equi_message_passing/~message/linear_1/w -params/message_layers_0/diachi1 -> alpha_net_hiku/~/equi_message_passing/diachi1 -params/message_layers_0/diagonal/layers_0/bias -> alpha_net_hiku/~/equi_message_passing/~_build_diagonal/linear/b -params/message_layers_0/diagonal/layers_0/kernel -> alpha_net_hiku/~/equi_message_passing/~_build_diagonal/linear/w -params/message_layers_0/diagonal/layers_2/bias -> alpha_net_hiku/~/equi_message_passing/~_build_diagonal/linear_1/b -params/message_layers_0/diagonal/layers_2/kernel -> alpha_net_hiku/~/equi_message_passing/~_build_diagonal/linear_1/w -params/message_layers_0/dir_proj/layers_0/bias -> alpha_net_hiku/~/equi_message_passing/~_build_dir_proj/linear/b -params/message_layers_0/dir_proj/layers_0/kernel -> alpha_net_hiku/~/equi_message_passing/~_build_dir_proj/linear/w -params/message_layers_0/dir_proj/layers_2/bias -> alpha_net_hiku/~/equi_message_passing/~_build_dir_proj/linear_1/b -params/message_layers_0/dir_proj/layers_2/kernel -> alpha_net_hiku/~/equi_message_passing/~_build_dir_proj/linear_1/w -params/message_layers_0/dx_layer_norm/bias -> alpha_net_hiku/~/equi_message_passing/layer_norm_1/offset -params/message_layers_0/dx_layer_norm/scale -> alpha_net_hiku/~/equi_message_passing/layer_norm_1/scale -params/message_layers_0/fc_mps/bias -> alpha_net_hiku/~/equi_message_passing/~/linear/b -params/message_layers_0/fc_mps/kernel -> alpha_net_hiku/~/equi_message_passing/~/linear/w -params/message_layers_0/kernel_imag -> alpha_net_hiku/~/equi_message_passing/kernel_imag -params/message_layers_0/kernel_real -> alpha_net_hiku/~/equi_message_passing/kernel_real -params/message_layers_0/rbf_proj/bias -> alpha_net_hiku/~/equi_message_passing/linear/b -params/message_layers_0/rbf_proj/kernel -> alpha_net_hiku/~/equi_message_passing/linear/w -params/message_layers_0/scale/bias -> alpha_net_hiku/~/equi_message_passing/~message/linear/b -params/message_layers_0/scale/kernel -> alpha_net_hiku/~/equi_message_passing/~message/linear/w -params/message_layers_0/scale2/bias -> alpha_net_hiku/~/equi_message_passing/linear_1/b -params/message_layers_0/scale2/kernel -> alpha_net_hiku/~/equi_message_passing/linear_1/w -params/message_layers_0/x_layernorm/bias -> alpha_net_hiku/~/equi_message_passing/layer_norm/offset 
-params/message_layers_0/x_layernorm/scale -> alpha_net_hiku/~/equi_message_passing/layer_norm/scale -params/message_layers_0/x_proj/layers_0/bias -> alpha_net_hiku/~/equi_message_passing/~_build_x_proj/linear/b -params/message_layers_0/x_proj/layers_0/kernel -> alpha_net_hiku/~/equi_message_passing/~_build_x_proj/linear/w -params/message_layers_0/x_proj/layers_2/bias -> alpha_net_hiku/~/equi_message_passing/~_build_x_proj/linear_1/b -params/message_layers_0/x_proj/layers_2/kernel -> alpha_net_hiku/~/equi_message_passing/~_build_x_proj/linear_1/w -params/message_layers_1/dia/bias -> alpha_net_hiku/~/equi_message_passing_1/~message/linear_1/b -params/message_layers_1/dia/kernel -> alpha_net_hiku/~/equi_message_passing_1/~message/linear_1/w -params/message_layers_1/diachi1 -> alpha_net_hiku/~/equi_message_passing_1/diachi1 -params/message_layers_1/diagonal/layers_0/bias -> alpha_net_hiku/~/equi_message_passing_1/~_build_diagonal/linear/b -params/message_layers_1/diagonal/layers_0/kernel -> alpha_net_hiku/~/equi_message_passing_1/~_build_diagonal/linear/w -params/message_layers_1/diagonal/layers_2/bias -> alpha_net_hiku/~/equi_message_passing_1/~_build_diagonal/linear_1/b -params/message_layers_1/diagonal/layers_2/kernel -> alpha_net_hiku/~/equi_message_passing_1/~_build_diagonal/linear_1/w -params/message_layers_1/dir_proj/layers_0/bias -> alpha_net_hiku/~/equi_message_passing_1/~_build_dir_proj/linear/b -params/message_layers_1/dir_proj/layers_0/kernel -> alpha_net_hiku/~/equi_message_passing_1/~_build_dir_proj/linear/w -params/message_layers_1/dir_proj/layers_2/bias -> alpha_net_hiku/~/equi_message_passing_1/~_build_dir_proj/linear_1/b -params/message_layers_1/dir_proj/layers_2/kernel -> alpha_net_hiku/~/equi_message_passing_1/~_build_dir_proj/linear_1/w -params/message_layers_1/dx_layer_norm/bias -> alpha_net_hiku/~/equi_message_passing_1/layer_norm_1/offset -params/message_layers_1/dx_layer_norm/scale -> alpha_net_hiku/~/equi_message_passing_1/layer_norm_1/scale -params/message_layers_1/fc_mps/bias -> alpha_net_hiku/~/equi_message_passing_1/~/linear/b -params/message_layers_1/fc_mps/kernel -> alpha_net_hiku/~/equi_message_passing_1/~/linear/w -params/message_layers_1/kernel_imag -> alpha_net_hiku/~/equi_message_passing_1/kernel_imag -params/message_layers_1/kernel_real -> alpha_net_hiku/~/equi_message_passing_1/kernel_real -params/message_layers_1/rbf_proj/bias -> alpha_net_hiku/~/equi_message_passing_1/linear/b -params/message_layers_1/rbf_proj/kernel -> alpha_net_hiku/~/equi_message_passing_1/linear/w -params/message_layers_1/scale/bias -> alpha_net_hiku/~/equi_message_passing_1/~message/linear/b -params/message_layers_1/scale/kernel -> alpha_net_hiku/~/equi_message_passing_1/~message/linear/w -params/message_layers_1/scale2/bias -> alpha_net_hiku/~/equi_message_passing_1/linear_1/b -params/message_layers_1/scale2/kernel -> alpha_net_hiku/~/equi_message_passing_1/linear_1/w -params/message_layers_1/x_layernorm/bias -> alpha_net_hiku/~/equi_message_passing_1/layer_norm/offset -params/message_layers_1/x_layernorm/scale -> alpha_net_hiku/~/equi_message_passing_1/layer_norm/scale -params/message_layers_1/x_proj/layers_0/bias -> alpha_net_hiku/~/equi_message_passing_1/~_build_x_proj/linear/b -params/message_layers_1/x_proj/layers_0/kernel -> alpha_net_hiku/~/equi_message_passing_1/~_build_x_proj/linear/w -params/message_layers_1/x_proj/layers_2/bias -> alpha_net_hiku/~/equi_message_passing_1/~_build_x_proj/linear_1/b -params/message_layers_1/x_proj/layers_2/kernel -> 
alpha_net_hiku/~/equi_message_passing_1/~_build_x_proj/linear_1/w -params/message_layers_2/dia/bias -> alpha_net_hiku/~/equi_message_passing_2/~message/linear_1/b -params/message_layers_2/dia/kernel -> alpha_net_hiku/~/equi_message_passing_2/~message/linear_1/w -params/message_layers_2/diachi1 -> alpha_net_hiku/~/equi_message_passing_2/diachi1 -params/message_layers_2/diagonal/layers_0/bias -> alpha_net_hiku/~/equi_message_passing_2/~_build_diagonal/linear/b -params/message_layers_2/diagonal/layers_0/kernel -> alpha_net_hiku/~/equi_message_passing_2/~_build_diagonal/linear/w -params/message_layers_2/diagonal/layers_2/bias -> alpha_net_hiku/~/equi_message_passing_2/~_build_diagonal/linear_1/b -params/message_layers_2/diagonal/layers_2/kernel -> alpha_net_hiku/~/equi_message_passing_2/~_build_diagonal/linear_1/w -params/message_layers_2/dir_proj/layers_0/bias -> alpha_net_hiku/~/equi_message_passing_2/~_build_dir_proj/linear/b -params/message_layers_2/dir_proj/layers_0/kernel -> alpha_net_hiku/~/equi_message_passing_2/~_build_dir_proj/linear/w -params/message_layers_2/dir_proj/layers_2/bias -> alpha_net_hiku/~/equi_message_passing_2/~_build_dir_proj/linear_1/b -params/message_layers_2/dir_proj/layers_2/kernel -> alpha_net_hiku/~/equi_message_passing_2/~_build_dir_proj/linear_1/w -params/message_layers_2/dx_layer_norm/bias -> alpha_net_hiku/~/equi_message_passing_2/layer_norm_1/offset -params/message_layers_2/dx_layer_norm/scale -> alpha_net_hiku/~/equi_message_passing_2/layer_norm_1/scale -params/message_layers_2/fc_mps/bias -> alpha_net_hiku/~/equi_message_passing_2/~/linear/b -params/message_layers_2/fc_mps/kernel -> alpha_net_hiku/~/equi_message_passing_2/~/linear/w -params/message_layers_2/kernel_imag -> alpha_net_hiku/~/equi_message_passing_2/kernel_imag -params/message_layers_2/kernel_real -> alpha_net_hiku/~/equi_message_passing_2/kernel_real -params/message_layers_2/rbf_proj/bias -> alpha_net_hiku/~/equi_message_passing_2/linear/b -params/message_layers_2/rbf_proj/kernel -> alpha_net_hiku/~/equi_message_passing_2/linear/w -params/message_layers_2/scale/bias -> alpha_net_hiku/~/equi_message_passing_2/~message/linear/b -params/message_layers_2/scale/kernel -> alpha_net_hiku/~/equi_message_passing_2/~message/linear/w -params/message_layers_2/scale2/bias -> alpha_net_hiku/~/equi_message_passing_2/linear_1/b -params/message_layers_2/scale2/kernel -> alpha_net_hiku/~/equi_message_passing_2/linear_1/w -params/message_layers_2/x_layernorm/bias -> alpha_net_hiku/~/equi_message_passing_2/layer_norm/offset -params/message_layers_2/x_layernorm/scale -> alpha_net_hiku/~/equi_message_passing_2/layer_norm/scale -params/message_layers_2/x_proj/layers_0/bias -> alpha_net_hiku/~/equi_message_passing_2/~_build_x_proj/linear/b -params/message_layers_2/x_proj/layers_0/kernel -> alpha_net_hiku/~/equi_message_passing_2/~_build_x_proj/linear/w -params/message_layers_2/x_proj/layers_2/bias -> alpha_net_hiku/~/equi_message_passing_2/~_build_x_proj/linear_1/b -params/message_layers_2/x_proj/layers_2/kernel -> alpha_net_hiku/~/equi_message_passing_2/~_build_x_proj/linear_1/w -params/message_layers_3/dia/bias -> alpha_net_hiku/~/equi_message_passing_3/~message/linear_1/b -params/message_layers_3/dia/kernel -> alpha_net_hiku/~/equi_message_passing_3/~message/linear_1/w -params/message_layers_3/diachi1 -> alpha_net_hiku/~/equi_message_passing_3/diachi1 -params/message_layers_3/diagonal/layers_0/bias -> alpha_net_hiku/~/equi_message_passing_3/~_build_diagonal/linear/b 
-params/message_layers_3/diagonal/layers_0/kernel -> alpha_net_hiku/~/equi_message_passing_3/~_build_diagonal/linear/w -params/message_layers_3/diagonal/layers_2/bias -> alpha_net_hiku/~/equi_message_passing_3/~_build_diagonal/linear_1/b -params/message_layers_3/diagonal/layers_2/kernel -> alpha_net_hiku/~/equi_message_passing_3/~_build_diagonal/linear_1/w -params/message_layers_3/dir_proj/layers_0/bias -> alpha_net_hiku/~/equi_message_passing_3/~_build_dir_proj/linear/b -params/message_layers_3/dir_proj/layers_0/kernel -> alpha_net_hiku/~/equi_message_passing_3/~_build_dir_proj/linear/w -params/message_layers_3/dir_proj/layers_2/bias -> alpha_net_hiku/~/equi_message_passing_3/~_build_dir_proj/linear_1/b -params/message_layers_3/dir_proj/layers_2/kernel -> alpha_net_hiku/~/equi_message_passing_3/~_build_dir_proj/linear_1/w -params/message_layers_3/dx_layer_norm/bias -> alpha_net_hiku/~/equi_message_passing_3/layer_norm_1/offset -params/message_layers_3/dx_layer_norm/scale -> alpha_net_hiku/~/equi_message_passing_3/layer_norm_1/scale -params/message_layers_3/fc_mps/bias -> alpha_net_hiku/~/equi_message_passing_3/~/linear/b -params/message_layers_3/fc_mps/kernel -> alpha_net_hiku/~/equi_message_passing_3/~/linear/w -params/message_layers_3/kernel_imag -> alpha_net_hiku/~/equi_message_passing_3/kernel_imag -params/message_layers_3/kernel_real -> alpha_net_hiku/~/equi_message_passing_3/kernel_real -params/message_layers_3/rbf_proj/bias -> alpha_net_hiku/~/equi_message_passing_3/linear/b -params/message_layers_3/rbf_proj/kernel -> alpha_net_hiku/~/equi_message_passing_3/linear/w -params/message_layers_3/scale/bias -> alpha_net_hiku/~/equi_message_passing_3/~message/linear/b -params/message_layers_3/scale/kernel -> alpha_net_hiku/~/equi_message_passing_3/~message/linear/w -params/message_layers_3/scale2/bias -> alpha_net_hiku/~/equi_message_passing_3/linear_1/b -params/message_layers_3/scale2/kernel -> alpha_net_hiku/~/equi_message_passing_3/linear_1/w -params/message_layers_3/x_layernorm/bias -> alpha_net_hiku/~/equi_message_passing_3/layer_norm/offset -params/message_layers_3/x_layernorm/scale -> alpha_net_hiku/~/equi_message_passing_3/layer_norm/scale -params/message_layers_3/x_proj/layers_0/bias -> alpha_net_hiku/~/equi_message_passing_3/~_build_x_proj/linear/b -params/message_layers_3/x_proj/layers_0/kernel -> alpha_net_hiku/~/equi_message_passing_3/~_build_x_proj/linear/w -params/message_layers_3/x_proj/layers_2/bias -> alpha_net_hiku/~/equi_message_passing_3/~_build_x_proj/linear_1/b -params/message_layers_3/x_proj/layers_2/kernel -> alpha_net_hiku/~/equi_message_passing_3/~_build_x_proj/linear_1/w -params/neighbor_emb/Embed_0/embedding -> alpha_net_hiku/~/neighbor_emb/embed/embeddings -params/radial_emb/bessel_weights -> alpha_net_hiku/~/rbf_emb/bessel_weights -params/radial_lin/layers_0/bias -> alpha_net_hiku/~/linear/b -params/radial_lin/layers_0/kernel -> alpha_net_hiku/~/linear/w -params/radial_lin/layers_2/bias -> alpha_net_hiku/~/linear_1/b -params/radial_lin/layers_2/kernel -> alpha_net_hiku/~/linear_1/w -params/s_vector/Dense_0/bias -> alpha_net_hiku/~/s_vector/linear/b -params/s_vector/Dense_0/kernel -> alpha_net_hiku/~/s_vector/linear/w -params/z_emb/embedding -> alpha_net_hiku/~/embed/embeddings diff --git a/pretrained/MATPES/README.md b/pretrained/MATPES/README.md deleted file mode 100644 index a25b077..0000000 --- a/pretrained/MATPES/README.md +++ /dev/null @@ -1,17 +0,0 @@ -# AlphaNet-MATPES-r2scan - -A model trained on the [MATPES](https://matpes.ai/) R2SCAN level dataset. 
- -## Model Details - -* **Parameters:** Approximately 16M - -## Access the Model - -The following resources are available in the `pretrained_models/MATPES` path: - -* **Model Configuration:** `matpes.json` -* **Model state\_dict:** Pre-trained weights `r2scan_1021.ckpt` - -## Performance -Remain to be evaluated \ No newline at end of file diff --git a/pretrained/MATPES/haiku_model_converted/conversion_map.txt b/pretrained/MATPES/haiku_model_converted/conversion_map.txt deleted file mode 100644 index 1371d3a..0000000 --- a/pretrained/MATPES/haiku_model_converted/conversion_map.txt +++ /dev/null @@ -1,160 +0,0 @@ -Flax Key -> Haiku Key Mapping --------------------------------------------------------------------------------- -params/a -> alpha_net_hiku/a -params/b -> alpha_net_hiku/b -params/ftes_0/vec_proj/kernel -> alpha_net_hiku/~/fte/~/linear/w -params/ftes_0/xvec_proj/layers_0/bias -> alpha_net_hiku/~/fte/~/linear_1/b -params/ftes_0/xvec_proj/layers_0/kernel -> alpha_net_hiku/~/fte/~/linear_1/w -params/ftes_0/xvec_proj/layers_2/bias -> alpha_net_hiku/~/fte/~/linear_2/b -params/ftes_0/xvec_proj/layers_2/kernel -> alpha_net_hiku/~/fte/~/linear_2/w -params/ftes_1/vec_proj/kernel -> alpha_net_hiku/~/fte_1/~/linear/w -params/ftes_1/xvec_proj/layers_0/bias -> alpha_net_hiku/~/fte_1/~/linear_1/b -params/ftes_1/xvec_proj/layers_0/kernel -> alpha_net_hiku/~/fte_1/~/linear_1/w -params/ftes_1/xvec_proj/layers_2/bias -> alpha_net_hiku/~/fte_1/~/linear_2/b -params/ftes_1/xvec_proj/layers_2/kernel -> alpha_net_hiku/~/fte_1/~/linear_2/w -params/ftes_2/vec_proj/kernel -> alpha_net_hiku/~/fte_2/~/linear/w -params/ftes_2/xvec_proj/layers_0/bias -> alpha_net_hiku/~/fte_2/~/linear_1/b -params/ftes_2/xvec_proj/layers_0/kernel -> alpha_net_hiku/~/fte_2/~/linear_1/w -params/ftes_2/xvec_proj/layers_2/bias -> alpha_net_hiku/~/fte_2/~/linear_2/b -params/ftes_2/xvec_proj/layers_2/kernel -> alpha_net_hiku/~/fte_2/~/linear_2/w -params/ftes_3/vec_proj/kernel -> alpha_net_hiku/~/fte_3/~/linear/w -params/ftes_3/xvec_proj/layers_0/bias -> alpha_net_hiku/~/fte_3/~/linear_1/b -params/ftes_3/xvec_proj/layers_0/kernel -> alpha_net_hiku/~/fte_3/~/linear_1/w -params/ftes_3/xvec_proj/layers_2/bias -> alpha_net_hiku/~/fte_3/~/linear_2/b -params/ftes_3/xvec_proj/layers_2/kernel -> alpha_net_hiku/~/fte_3/~/linear_2/w -params/kernel1 -> alpha_net_hiku/kernel1 -params/kernels_imag -> alpha_net_hiku/kernels_imag -params/kernels_real -> alpha_net_hiku/kernels_real -params/last_layer/bias -> alpha_net_hiku/~/linear_4/b -params/last_layer/kernel -> alpha_net_hiku/~/linear_4/w -params/last_layer_quantum/bias -> alpha_net_hiku/~/linear_5/b -params/last_layer_quantum/kernel -> alpha_net_hiku/~/linear_5/w -params/lin/layers_0/bias -> alpha_net_hiku/~/linear_2/b -params/lin/layers_0/kernel -> alpha_net_hiku/~/linear_2/w -params/lin/layers_2/bias -> alpha_net_hiku/~/linear_3/b -params/lin/layers_2/kernel -> alpha_net_hiku/~/linear_3/w -params/message_layers_0/dia/bias -> alpha_net_hiku/~/equi_message_passing/~message/linear_1/b -params/message_layers_0/dia/kernel -> alpha_net_hiku/~/equi_message_passing/~message/linear_1/w -params/message_layers_0/diachi1 -> alpha_net_hiku/~/equi_message_passing/diachi1 -params/message_layers_0/diagonal/layers_0/bias -> alpha_net_hiku/~/equi_message_passing/~_build_diagonal/linear/b -params/message_layers_0/diagonal/layers_0/kernel -> alpha_net_hiku/~/equi_message_passing/~_build_diagonal/linear/w -params/message_layers_0/diagonal/layers_2/bias -> 
alpha_net_hiku/~/equi_message_passing/~_build_diagonal/linear_1/b -params/message_layers_0/diagonal/layers_2/kernel -> alpha_net_hiku/~/equi_message_passing/~_build_diagonal/linear_1/w -params/message_layers_0/dir_proj/layers_0/bias -> alpha_net_hiku/~/equi_message_passing/~_build_dir_proj/linear/b -params/message_layers_0/dir_proj/layers_0/kernel -> alpha_net_hiku/~/equi_message_passing/~_build_dir_proj/linear/w -params/message_layers_0/dir_proj/layers_2/bias -> alpha_net_hiku/~/equi_message_passing/~_build_dir_proj/linear_1/b -params/message_layers_0/dir_proj/layers_2/kernel -> alpha_net_hiku/~/equi_message_passing/~_build_dir_proj/linear_1/w -params/message_layers_0/dx_layer_norm/bias -> alpha_net_hiku/~/equi_message_passing/layer_norm_1/offset -params/message_layers_0/dx_layer_norm/scale -> alpha_net_hiku/~/equi_message_passing/layer_norm_1/scale -params/message_layers_0/fc_mps/bias -> alpha_net_hiku/~/equi_message_passing/~/linear/b -params/message_layers_0/fc_mps/kernel -> alpha_net_hiku/~/equi_message_passing/~/linear/w -params/message_layers_0/kernel_imag -> alpha_net_hiku/~/equi_message_passing/kernel_imag -params/message_layers_0/kernel_real -> alpha_net_hiku/~/equi_message_passing/kernel_real -params/message_layers_0/rbf_proj/bias -> alpha_net_hiku/~/equi_message_passing/linear/b -params/message_layers_0/rbf_proj/kernel -> alpha_net_hiku/~/equi_message_passing/linear/w -params/message_layers_0/scale/bias -> alpha_net_hiku/~/equi_message_passing/~message/linear/b -params/message_layers_0/scale/kernel -> alpha_net_hiku/~/equi_message_passing/~message/linear/w -params/message_layers_0/scale2/bias -> alpha_net_hiku/~/equi_message_passing/linear_1/b -params/message_layers_0/scale2/kernel -> alpha_net_hiku/~/equi_message_passing/linear_1/w -params/message_layers_0/x_layernorm/bias -> alpha_net_hiku/~/equi_message_passing/layer_norm/offset -params/message_layers_0/x_layernorm/scale -> alpha_net_hiku/~/equi_message_passing/layer_norm/scale -params/message_layers_0/x_proj/layers_0/bias -> alpha_net_hiku/~/equi_message_passing/~_build_x_proj/linear/b -params/message_layers_0/x_proj/layers_0/kernel -> alpha_net_hiku/~/equi_message_passing/~_build_x_proj/linear/w -params/message_layers_0/x_proj/layers_2/bias -> alpha_net_hiku/~/equi_message_passing/~_build_x_proj/linear_1/b -params/message_layers_0/x_proj/layers_2/kernel -> alpha_net_hiku/~/equi_message_passing/~_build_x_proj/linear_1/w -params/message_layers_1/dia/bias -> alpha_net_hiku/~/equi_message_passing_1/~message/linear_1/b -params/message_layers_1/dia/kernel -> alpha_net_hiku/~/equi_message_passing_1/~message/linear_1/w -params/message_layers_1/diachi1 -> alpha_net_hiku/~/equi_message_passing_1/diachi1 -params/message_layers_1/diagonal/layers_0/bias -> alpha_net_hiku/~/equi_message_passing_1/~_build_diagonal/linear/b -params/message_layers_1/diagonal/layers_0/kernel -> alpha_net_hiku/~/equi_message_passing_1/~_build_diagonal/linear/w -params/message_layers_1/diagonal/layers_2/bias -> alpha_net_hiku/~/equi_message_passing_1/~_build_diagonal/linear_1/b -params/message_layers_1/diagonal/layers_2/kernel -> alpha_net_hiku/~/equi_message_passing_1/~_build_diagonal/linear_1/w -params/message_layers_1/dir_proj/layers_0/bias -> alpha_net_hiku/~/equi_message_passing_1/~_build_dir_proj/linear/b -params/message_layers_1/dir_proj/layers_0/kernel -> alpha_net_hiku/~/equi_message_passing_1/~_build_dir_proj/linear/w -params/message_layers_1/dir_proj/layers_2/bias -> alpha_net_hiku/~/equi_message_passing_1/~_build_dir_proj/linear_1/b 
-params/message_layers_1/dir_proj/layers_2/kernel -> alpha_net_hiku/~/equi_message_passing_1/~_build_dir_proj/linear_1/w -params/message_layers_1/dx_layer_norm/bias -> alpha_net_hiku/~/equi_message_passing_1/layer_norm_1/offset -params/message_layers_1/dx_layer_norm/scale -> alpha_net_hiku/~/equi_message_passing_1/layer_norm_1/scale -params/message_layers_1/fc_mps/bias -> alpha_net_hiku/~/equi_message_passing_1/~/linear/b -params/message_layers_1/fc_mps/kernel -> alpha_net_hiku/~/equi_message_passing_1/~/linear/w -params/message_layers_1/kernel_imag -> alpha_net_hiku/~/equi_message_passing_1/kernel_imag -params/message_layers_1/kernel_real -> alpha_net_hiku/~/equi_message_passing_1/kernel_real -params/message_layers_1/rbf_proj/bias -> alpha_net_hiku/~/equi_message_passing_1/linear/b -params/message_layers_1/rbf_proj/kernel -> alpha_net_hiku/~/equi_message_passing_1/linear/w -params/message_layers_1/scale/bias -> alpha_net_hiku/~/equi_message_passing_1/~message/linear/b -params/message_layers_1/scale/kernel -> alpha_net_hiku/~/equi_message_passing_1/~message/linear/w -params/message_layers_1/scale2/bias -> alpha_net_hiku/~/equi_message_passing_1/linear_1/b -params/message_layers_1/scale2/kernel -> alpha_net_hiku/~/equi_message_passing_1/linear_1/w -params/message_layers_1/x_layernorm/bias -> alpha_net_hiku/~/equi_message_passing_1/layer_norm/offset -params/message_layers_1/x_layernorm/scale -> alpha_net_hiku/~/equi_message_passing_1/layer_norm/scale -params/message_layers_1/x_proj/layers_0/bias -> alpha_net_hiku/~/equi_message_passing_1/~_build_x_proj/linear/b -params/message_layers_1/x_proj/layers_0/kernel -> alpha_net_hiku/~/equi_message_passing_1/~_build_x_proj/linear/w -params/message_layers_1/x_proj/layers_2/bias -> alpha_net_hiku/~/equi_message_passing_1/~_build_x_proj/linear_1/b -params/message_layers_1/x_proj/layers_2/kernel -> alpha_net_hiku/~/equi_message_passing_1/~_build_x_proj/linear_1/w -params/message_layers_2/dia/bias -> alpha_net_hiku/~/equi_message_passing_2/~message/linear_1/b -params/message_layers_2/dia/kernel -> alpha_net_hiku/~/equi_message_passing_2/~message/linear_1/w -params/message_layers_2/diachi1 -> alpha_net_hiku/~/equi_message_passing_2/diachi1 -params/message_layers_2/diagonal/layers_0/bias -> alpha_net_hiku/~/equi_message_passing_2/~_build_diagonal/linear/b -params/message_layers_2/diagonal/layers_0/kernel -> alpha_net_hiku/~/equi_message_passing_2/~_build_diagonal/linear/w -params/message_layers_2/diagonal/layers_2/bias -> alpha_net_hiku/~/equi_message_passing_2/~_build_diagonal/linear_1/b -params/message_layers_2/diagonal/layers_2/kernel -> alpha_net_hiku/~/equi_message_passing_2/~_build_diagonal/linear_1/w -params/message_layers_2/dir_proj/layers_0/bias -> alpha_net_hiku/~/equi_message_passing_2/~_build_dir_proj/linear/b -params/message_layers_2/dir_proj/layers_0/kernel -> alpha_net_hiku/~/equi_message_passing_2/~_build_dir_proj/linear/w -params/message_layers_2/dir_proj/layers_2/bias -> alpha_net_hiku/~/equi_message_passing_2/~_build_dir_proj/linear_1/b -params/message_layers_2/dir_proj/layers_2/kernel -> alpha_net_hiku/~/equi_message_passing_2/~_build_dir_proj/linear_1/w -params/message_layers_2/dx_layer_norm/bias -> alpha_net_hiku/~/equi_message_passing_2/layer_norm_1/offset -params/message_layers_2/dx_layer_norm/scale -> alpha_net_hiku/~/equi_message_passing_2/layer_norm_1/scale -params/message_layers_2/fc_mps/bias -> alpha_net_hiku/~/equi_message_passing_2/~/linear/b -params/message_layers_2/fc_mps/kernel -> 
alpha_net_hiku/~/equi_message_passing_2/~/linear/w -params/message_layers_2/kernel_imag -> alpha_net_hiku/~/equi_message_passing_2/kernel_imag -params/message_layers_2/kernel_real -> alpha_net_hiku/~/equi_message_passing_2/kernel_real -params/message_layers_2/rbf_proj/bias -> alpha_net_hiku/~/equi_message_passing_2/linear/b -params/message_layers_2/rbf_proj/kernel -> alpha_net_hiku/~/equi_message_passing_2/linear/w -params/message_layers_2/scale/bias -> alpha_net_hiku/~/equi_message_passing_2/~message/linear/b -params/message_layers_2/scale/kernel -> alpha_net_hiku/~/equi_message_passing_2/~message/linear/w -params/message_layers_2/scale2/bias -> alpha_net_hiku/~/equi_message_passing_2/linear_1/b -params/message_layers_2/scale2/kernel -> alpha_net_hiku/~/equi_message_passing_2/linear_1/w -params/message_layers_2/x_layernorm/bias -> alpha_net_hiku/~/equi_message_passing_2/layer_norm/offset -params/message_layers_2/x_layernorm/scale -> alpha_net_hiku/~/equi_message_passing_2/layer_norm/scale -params/message_layers_2/x_proj/layers_0/bias -> alpha_net_hiku/~/equi_message_passing_2/~_build_x_proj/linear/b -params/message_layers_2/x_proj/layers_0/kernel -> alpha_net_hiku/~/equi_message_passing_2/~_build_x_proj/linear/w -params/message_layers_2/x_proj/layers_2/bias -> alpha_net_hiku/~/equi_message_passing_2/~_build_x_proj/linear_1/b -params/message_layers_2/x_proj/layers_2/kernel -> alpha_net_hiku/~/equi_message_passing_2/~_build_x_proj/linear_1/w -params/message_layers_3/dia/bias -> alpha_net_hiku/~/equi_message_passing_3/~message/linear_1/b -params/message_layers_3/dia/kernel -> alpha_net_hiku/~/equi_message_passing_3/~message/linear_1/w -params/message_layers_3/diachi1 -> alpha_net_hiku/~/equi_message_passing_3/diachi1 -params/message_layers_3/diagonal/layers_0/bias -> alpha_net_hiku/~/equi_message_passing_3/~_build_diagonal/linear/b -params/message_layers_3/diagonal/layers_0/kernel -> alpha_net_hiku/~/equi_message_passing_3/~_build_diagonal/linear/w -params/message_layers_3/diagonal/layers_2/bias -> alpha_net_hiku/~/equi_message_passing_3/~_build_diagonal/linear_1/b -params/message_layers_3/diagonal/layers_2/kernel -> alpha_net_hiku/~/equi_message_passing_3/~_build_diagonal/linear_1/w -params/message_layers_3/dir_proj/layers_0/bias -> alpha_net_hiku/~/equi_message_passing_3/~_build_dir_proj/linear/b -params/message_layers_3/dir_proj/layers_0/kernel -> alpha_net_hiku/~/equi_message_passing_3/~_build_dir_proj/linear/w -params/message_layers_3/dir_proj/layers_2/bias -> alpha_net_hiku/~/equi_message_passing_3/~_build_dir_proj/linear_1/b -params/message_layers_3/dir_proj/layers_2/kernel -> alpha_net_hiku/~/equi_message_passing_3/~_build_dir_proj/linear_1/w -params/message_layers_3/dx_layer_norm/bias -> alpha_net_hiku/~/equi_message_passing_3/layer_norm_1/offset -params/message_layers_3/dx_layer_norm/scale -> alpha_net_hiku/~/equi_message_passing_3/layer_norm_1/scale -params/message_layers_3/fc_mps/bias -> alpha_net_hiku/~/equi_message_passing_3/~/linear/b -params/message_layers_3/fc_mps/kernel -> alpha_net_hiku/~/equi_message_passing_3/~/linear/w -params/message_layers_3/kernel_imag -> alpha_net_hiku/~/equi_message_passing_3/kernel_imag -params/message_layers_3/kernel_real -> alpha_net_hiku/~/equi_message_passing_3/kernel_real -params/message_layers_3/rbf_proj/bias -> alpha_net_hiku/~/equi_message_passing_3/linear/b -params/message_layers_3/rbf_proj/kernel -> alpha_net_hiku/~/equi_message_passing_3/linear/w -params/message_layers_3/scale/bias -> 
alpha_net_hiku/~/equi_message_passing_3/~message/linear/b -params/message_layers_3/scale/kernel -> alpha_net_hiku/~/equi_message_passing_3/~message/linear/w -params/message_layers_3/scale2/bias -> alpha_net_hiku/~/equi_message_passing_3/linear_1/b -params/message_layers_3/scale2/kernel -> alpha_net_hiku/~/equi_message_passing_3/linear_1/w -params/message_layers_3/x_layernorm/bias -> alpha_net_hiku/~/equi_message_passing_3/layer_norm/offset -params/message_layers_3/x_layernorm/scale -> alpha_net_hiku/~/equi_message_passing_3/layer_norm/scale -params/message_layers_3/x_proj/layers_0/bias -> alpha_net_hiku/~/equi_message_passing_3/~_build_x_proj/linear/b -params/message_layers_3/x_proj/layers_0/kernel -> alpha_net_hiku/~/equi_message_passing_3/~_build_x_proj/linear/w -params/message_layers_3/x_proj/layers_2/bias -> alpha_net_hiku/~/equi_message_passing_3/~_build_x_proj/linear_1/b -params/message_layers_3/x_proj/layers_2/kernel -> alpha_net_hiku/~/equi_message_passing_3/~_build_x_proj/linear_1/w -params/neighbor_emb/Embed_0/embedding -> alpha_net_hiku/~/neighbor_emb/embed/embeddings -params/radial_emb/bessel_weights -> alpha_net_hiku/~/rbf_emb/bessel_weights -params/radial_lin/layers_0/bias -> alpha_net_hiku/~/linear/b -params/radial_lin/layers_0/kernel -> alpha_net_hiku/~/linear/w -params/radial_lin/layers_2/bias -> alpha_net_hiku/~/linear_1/b -params/radial_lin/layers_2/kernel -> alpha_net_hiku/~/linear_1/w -params/s_vector/Dense_0/bias -> alpha_net_hiku/~/s_vector/linear/b -params/s_vector/Dense_0/kernel -> alpha_net_hiku/~/s_vector/linear/w -params/z_emb/embedding -> alpha_net_hiku/~/embed/embeddings diff --git a/pretrained/MATPES/haiku_model_converted/haiku_model.pkl b/pretrained/MATPES/haiku_model_converted/haiku_model.pkl deleted file mode 100644 index 5bfdbe1..0000000 Binary files a/pretrained/MATPES/haiku_model_converted/haiku_model.pkl and /dev/null differ diff --git a/pretrained/MATPES/matpes.json b/pretrained/MATPES/matpes.json deleted file mode 100644 index 2618eeb..0000000 --- a/pretrained/MATPES/matpes.json +++ /dev/null @@ -1,53 +0,0 @@ -{ - "data": { - "root": "dataset/", - "dataset_name": "MP-TRAIN", - "target": "energy_force", - "train_dataset": "r2scan-train", - "valid_dataset": "r2scan-val", - - "test_dataset": "r2scan-val", - "seed": 42 - }, - "model": { - "name": "Alphanet", - "num_layers": 4, - "num_targets": 1, - "output_dim": 1, - "compute_forces": true, - "compute_stress": true, - "use_pbc": true, - "has_dropout_flag": false, - "hidden_channels": 176, - "main_chi1": 96, - "mp_chi1": 64, - "chi2": 24, - "cutoff": 5, - "num_radial": 8, - "zbl": true - }, - "train": { - "epochs": 200, - "batch_size": 16, - "accumulation_steps": 1, - "vt_batch_size": 16, - "lr": 0.0005, - "lr_decay_factor": 0.8, - "lr_decay_step_size": 80, - "weight_decay": 0, - "save_dir": "r2scan", - "log_dir": "", - "disable_tqdm": false, - "scheduler": "cosineannealinglr", - "norm_label": false, - "device": "cuda", - "energy_loss": "mae", - "force_loss": "mae", - "energy_metric": "mae", - "force_metric": "mae", - "energy_coef": 4.0, - "stress_coef": 10, - "force_coef": 100.0, - "eval_steps": 1 - } -} diff --git a/pretrained/MATPES/r2scan_1021.ckpt b/pretrained/MATPES/r2scan_1021.ckpt deleted file mode 100644 index 686c1c5..0000000 Binary files a/pretrained/MATPES/r2scan_1021.ckpt and /dev/null differ diff --git a/pretrained/MPtrj/README.md b/pretrained/MPtrj/README.md deleted file mode 100644 index c1749b2..0000000 --- a/pretrained/MPtrj/README.md +++ /dev/null @@ -1,37 +0,0 @@ -# AlphaNet-MPtrj-v1 - 
-A model trained on the MpTrj dataset. - -## Model Details - -* **Parameters:** Approximately 4.5 million - -## Access the Model - -The following resources are available in the `pretrained_models/MPtrj` path: - -* **Model Configuration:** `mp.json` -* **Model state\_dict:** Pre-trained weights can be downloaded from [Figshare](https://ndownloader.figshare.com/files/53851133). - -## Performance on WBM Test Set - -The detailed evaluation metrics for the model on the `full_test_set` are as follows: - -| Metric | Value | Unit/Description | -| :--- | :--- | :--- | -| F1 | 0.789 | fraction | -| DAF | 4.312 | dimensionless | -| Precision | 0.74 | fraction | -| Recall | 0.846 | fraction | -| Accuracy | 0.923 | fraction | -| TPR | 0.846 | fraction | -| FPR | 0.062 | fraction | -| TNR | 0.938 | fraction | -| FNR | 0.154 | fraction | -| TP | 37311.0 | count | -| FP | 13119.0 | count | -| TN | 199752.0 | count | -| FN | 6781.0 | count | -| MAE | 0.04 | eV/atom | -| RMSE | 0.091 | eV/atom | -| R2 | 0.747 | dimensionless | \ No newline at end of file diff --git a/pretrained/MPtrj/mp.json b/pretrained/MPtrj/mp.json deleted file mode 100644 index 3abf01f..0000000 --- a/pretrained/MPtrj/mp.json +++ /dev/null @@ -1,48 +0,0 @@ -{ - "data": { - "root": "dataset/", - "dataset_name": "MP-TRAIN", - "target": "energy_force", - "train_dataset": "MP_train", - "valid_dataset": "MP-train", - "test_dataset": "MP_val", - "seed": 42 - }, - "model": { - "name": "Alphanet", - "num_layers": 4, - "num_targets": 1, - "output_dim": 1, - "compute_forces": true, - "compute_stress": true, - "use_pbc": true, - "has_dropout_flag": false, - "hidden_channels": 176, - "cutoff": 5, - "num_radial": 8 - }, - "train": { - "epochs": 200, - "batch_size": 6, - "accumulation_steps": 1, - "vt_batch_size": 2, - "lr": 0.0005, - "lr_decay_factor": 0.8, - "lr_decay_step_size": 80, - "weight_decay": 0, - "save_dir": "wbm", - "log_dir": "", - "disable_tqdm": false, - "scheduler": "cosineannealinglr", - "norm_label": false, - "device": "cuda", - "energy_loss": "mae", - "force_loss": "mae", - "energy_metric": "mae", - "force_metric": "mae", - "energy_coef": 4.0, - "stress_coef": 10, - "force_coef": 100.0, - "eval_steps": 1 - } -} diff --git a/pretrained/OMA/README.md b/pretrained/OMA/README.md index 7c5ea71..c54c4b9 100644 --- a/pretrained/OMA/README.md +++ b/pretrained/OMA/README.md @@ -28,6 +28,6 @@ Same size with **AlphaNet-MPtrj-v1**, trained on OMAT24, and finetuned on sALEX+ The following resources are available in the directory: * **Model Configuration**: `oma.json` -* **Model `state_dict`**: Pre-trained weights can be downloaded from [Figshare](https://ndownloader.figshare.com/files/53851139). +* **Model `state_dict`**: [alex_1212.ckpt](./alex_1212.ckpt). 
-**Path**: `pretrained_models/OMA`
\ No newline at end of file
+**Path**: `pretrained_models/OMA`
diff --git a/pretrained/AQCAT25/haiku_model_converted/haiku_model.pkl b/pretrained/OMA/alex_1212.ckpt
similarity index 59%
rename from pretrained/AQCAT25/haiku_model_converted/haiku_model.pkl
rename to pretrained/OMA/alex_1212.ckpt
index e89d847..35df5cb 100644
Binary files a/pretrained/AQCAT25/haiku_model_converted/haiku_model.pkl and b/pretrained/OMA/alex_1212.ckpt differ
diff --git a/requirements.txt b/requirements.txt
index f314aec..9e4be01 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,6 +1,6 @@
 --extra-index-url https://download.pytorch.org/whl/cu121
-
+numpy==1.26.4
 torch==2.1.2
 torch_geometric
 lightning
diff --git a/run_lammps.py b/run_lammps.py
new file mode 100644
index 0000000..00e670c
--- /dev/null
+++ b/run_lammps.py
@@ -0,0 +1,116 @@
+import sys
+import os
+import torch
+import lammps as lmp_lib
+import lammps.mliap
+
+try:
+    import alphanet.infer.lammps_mliap_alphanet
+except ImportError:
+    sys.path.append(os.getcwd())
+    import alphanet.infer.lammps_mliap_alphanet
+
+
+_ORIGINAL_JIT_LOAD = torch.jit.load
+_ORIGINAL_TORCH_LOAD = torch.load
+
+GLOBAL_LOADED_MODEL = None
+
+def hijack_load(f, *args, **kwargs):
+    """
+    Intercept load requests from LAMMPS (or its Python interface) and
+    return the pre-loaded Python object directly.
+    """
+    f_str = str(f)
+    print(f"⚡ Intercepted load request for '{f_str}'")
+
+    if GLOBAL_LOADED_MODEL is not None:
+        print("   => Returning pre-loaded LAMMPS_MLIAP_ALPHANET Object!")
+        return GLOBAL_LOADED_MODEL
+
+    print("   => Warning: Global model not set. Fallback to original load.")
+    # Guess the loading mechanism from the file suffix
+    if f_str.endswith('.pt') or f_str.endswith('.pth'):
+        return _ORIGINAL_TORCH_LOAD(f, *args, **kwargs)
+    return _ORIGINAL_JIT_LOAD(f, *args, **kwargs)
+
+# Patch torch's load functions so the hook fires when LAMMPS calls them
+torch.jit.load = hijack_load
+torch.load = hijack_load
+
+# ---------------------------------------------------------
+# Main program
+# ---------------------------------------------------------
+if __name__ == "__main__":
+    input_file = "sl.in"   # LAMMPS input file
+    model_file = "sl.pt"   # file produced by create_lammps_model.py
+
+    if not os.path.exists(input_file):
+        print(f"❌ Error: {input_file} not found!")
+        sys.exit(1)
+
+    if not os.path.exists(model_file):
+        print(f"❌ Error: {model_file} not found!")
+        sys.exit(1)
+
+    print(f"✅ PyTorch {torch.__version__} loaded.")
+
+    # 1. Pick the device
+    if torch.cuda.is_available():
+        device = torch.device("cuda")
+        print("✅ Using device: CUDA (GPU)")
+    else:
+        device = torch.device("cpu")
+        print("⚠️ Using device: CPU")
+
+    # 2. Pre-load the model object
+    print(f"🛠️ Loading Python object from {model_file}...")
+    try:
+        # Use the original torch.load to deserialize the saved Python object
+        loaded_object = _ORIGINAL_TORCH_LOAD(model_file, map_location=device)
+
+        # Make sure the model and its parameters live on the chosen device
+        if hasattr(loaded_object, 'model'):
+            loaded_object.model.to(device)
+            loaded_object.device = device  # update the device recorded on the wrapper
+            loaded_object.model.eval()
+
+        # Register globally so the hijack hook can return it
+        GLOBAL_LOADED_MODEL = loaded_object
+        print("   Model loaded and registered successfully.")
+
+    except Exception as e:
+        import traceback
+        traceback.print_exc()
+        print(f"❌ Failed to load model: {e}")
+        sys.exit(1)
+
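+    # Usage note: this script is launched directly (e.g. `python run_lammps.py`)
+    # with the input script (`sl.in`) and the converted model (`sl.pt`) in the
+    # working directory; the Kokkos flags that would normally be passed on the
+    # `lmp` command line are supplied below via `cmdargs`.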
+    # 3. Configure LAMMPS
+    # -k on g 1: enable Kokkos and use 1 GPU
+    # neigh half: required by Kokkos by default; the half-list symmetrization
+    #             is already handled on the Python side
+    # newton on: must be enabled for force communication across MPI ranks
+    cmd_args = [
+        "-k", "on", "g", "1",
+        "-sf", "kk",
+        "-pk", "kokkos", "newton", "on", "neigh", "half"
+    ]
+
+    try:
+        print("🚀 Initializing LAMMPS...")
+        lmp = lmp_lib.lammps(cmdargs=cmd_args)
+
+        print("🔌 Activating ML-IAP Kokkos interface...")
+        # This call triggers the C++ side to call back into Python to load the
+        # model; it hits our hijack_load and receives GLOBAL_LOADED_MODEL
+        lammps.mliap.activate_mliappy_kokkos(lmp)
+
+        print(f"📂 Executing {input_file}...")
+        lmp.file(input_file)
+
+        print("🎉 LAMMPS simulation finished successfully.")
+
+    except Exception as e:
+        import traceback
+        traceback.print_exc()
+        print(f"❌ Error during simulation: {e}")
+        sys.exit(1)
\ No newline at end of file
diff --git a/water.json b/water.json
deleted file mode 100644
index 657de49..0000000
--- a/water.json
+++ /dev/null
@@ -1,48 +0,0 @@
-{
-  "data": {
-    "root": "dataset/",
-    "dataset_name": "t1",
-    "target": "energy_force",
-    "train_dataset": "water-train",
-    "valid_dataset": "water-test",
-    "test_dataset": "water-test",
-    "seed": 42
-  },
-  "model": {
-    "name": "Alphanet",
-    "num_layers": 3,
-    "num_targets": 1,
-    "output_dim": 1,
-    "compute_forces": true,
-    "compute_stress": true,
-    "use_pbc": true,
-    "has_dropout_flag": true,
-    "hidden_channels": 128,
-    "cutoff": 4,
-    "num_radial": 8
-  },
-  "train": {
-    "epochs": 500,
-    "batch_size":1,
-    "accumulation_steps": 16,
-    "vt_batch_size": 1,
-    "lr": 0.0005,
-    "lr_decay_factor": 0.8,
-    "lr_decay_step_size": 80,
-    "weight_decay": 0,
-    "save_dir": "gap",
-    "log_dir": "",
-    "disable_tqdm": false,
-    "scheduler": "cosineannealinglr",
-    "device": "cuda",
-    "energy_loss": "mae",
-    "force_loss": "mae",
-    "stress_loss": "mae",
-    "energy_metric": "mae",
-    "force_metric": "mae",
-    "energy_coef": 1.0,
-    "stress_coef": 0.5,
-    "force_coef": 100.0,
-    "eval_steps": 1
-  }
-}