24 changes: 24 additions & 0 deletions .gitignore
@@ -0,0 +1,24 @@
# --- .gitignore content start ---

# 1. Ignore LAMMPS log and trajectory files
log.lammps
*.log
*.lammpstrj

# 2. Ignore structure and data files
*.data
POSCAR
CH4.txt

# 3. Ignore large model weight files (pushing large files to git is usually discouraged unless you need them versioned)
*pt
*.pt
*.pkl

# 4. Ignore Python bytecode caches and packaging artifacts
__pycache__/
*.egg-info/
build/
dist/
*.zip
lammps/
# --- .gitignore content end ---
84 changes: 19 additions & 65 deletions README.md
@@ -6,17 +6,14 @@

We present **AlphaNet**, a local frame-based equivariant model designed to tackle the challenge of achieving both accurate and efficient simulations of atomistic systems. **AlphaNet** enhances computational efficiency and accuracy by leveraging the local geometric structures of atomic environments through the construction of equivariant local frames and learnable frame transitions. Inspired by quantum mechanics, AlphaNet **introduces efficient multi-body message passing via contraction of matrix product states** rather than the common 2-body message passing. Notably, AlphaNet offers one of the best trade-offs between computational efficiency and accuracy among existing models. Moreover, AlphaNet scales across a broad spectrum of system and dataset sizes, affirming its versatility.
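As a rough illustration of the idea (this is a schematic sketch, not AlphaNet's actual implementation), contracting a matrix product state couples all neighbor features in one chain contraction, whereas 2-body message passing only mixes pairs. Shapes and names below are illustrative:

```python
import torch

# Schematic only: a matrix product state represents a many-body tensor as a
# chain of small 3-index cores; contracting the whole chain couples every
# site at once, unlike pairwise (2-body) message passing.
n_sites, phys_dim, bond_dim = 4, 8, 16                       # illustrative sizes
cores = [torch.randn(bond_dim, phys_dim, bond_dim) for _ in range(n_sites)]
features = [torch.randn(phys_dim) for _ in range(n_sites)]   # per-neighbor features

msg = torch.eye(bond_dim)
for core, feat in zip(cores, features):
    # Contract each core with its neighbor feature, then chain the matrices.
    msg = msg @ torch.einsum("apb,p->ab", core, feat)
multi_body_message = msg.trace()                              # one scalar coupling all sites
print(multi_body_message)
```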
## Update Log (v0.1.2)
## Update Log (v0.1.2-beta)

### Major Changes

1. **Added 2 new pretrained models**
- Provide a pretrained model for materials, **AlphaNet-MATPES-r2scan**, and our first pretrained model for catalysis, **AlphaNet-AQCAT25**; see the [pretrained](./pretrained) folder.
- Users can **convert a checkpoint trained in torch to our JAX model**

2. **Fixed some bugs**
- Support non-periodic boundary conditions in our ASE calculator.
- Fixed errors in float64 precision
1. **Added a LAMMPS ML-IAP interface**
2. **Slightly changed the model architecture**
3. **Added a finetune option**



## Installation Guide
@@ -84,7 +81,11 @@ alpha-train example.json # use --help to see more functions, like multi-gpu trai
```bash
alpha-conv -i in.ckpt -o out.ckpt # use --help to see more functions
```
3. Evaluate a model and draw diagonal plot:
3. Finetune a converted checkpoint:
```bash
alpha-train example.json --finetune /path/to/your.ckpt
```
4. Evaluate a model and draw a diagonal plot:
```bash
alpha-eval -c example.json -m /path/to/ckpt # use --help to see more functions
```
Expand Down Expand Up @@ -142,67 +143,17 @@ print(atoms.get_potential_energy())

```

### Using AlphaNet in JAX
1. Installation
```bash
pip install -U --pre jax jaxlib "jax-cuda12-plugin[with-cuda]" jax-cuda12-pjrt -i https://us-python.pkg.dev/ml-oss-artifacts-published/jax/simple/
```
This is just for reference. JAX installation can be tricky; see the [JAX installation guide](https://docs.jax.dev/en/latest/installation.html) and its GitHub issues for more information.

Currently we suggest **jax >=0.4,<=0.4.10, or >=0.4.30,<=0.5, or ==0.6.2**.

Install matscipy, flax, and haiku:
```bash
pip install matscipy
pip install flax
pip install -U dm-haiku
```

2. Converted checkpoints:

See pretrained directory

3. Convert a self-trained checkpoint

First convert from torch to flax:
```bash
python scripts/conv_pt2flax.py  # edit the paths inside the script first
```
Then convert from flax to haiku:

```bash
python scripts/flax2haiku.py  # edit the paths inside the script first
```

4. Performance:

The output (energy, forces, stress) difference from the torch model should be below 0.001. We ran speed tests on a 4090 GPU with system sizes from 4 to 300 atoms and obtained a **2.5x to 3x** speed-up.

Please note that the JAX model needs to be compiled first, so the first run can take a few seconds or minutes, but runs are fast after that.
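A minimal, generic sketch of this warm-up behavior (this is plain `jax.jit`, not the AlphaNet JAX API; the function, names, and shapes are illustrative):

```python
import time
import jax
import jax.numpy as jnp

# Illustrative only: the first call traces and compiles the function,
# later calls reuse the compiled kernel and are much faster.
@jax.jit
def energy(positions):
    return jnp.sum(positions ** 2)

x = jnp.ones((300, 3))

t0 = time.time()
energy(x).block_until_ready()   # includes compilation
t1 = time.time()
energy(x).block_until_ready()   # already compiled
t2 = time.time()
print(f"first call: {t1 - t0:.3f}s, second call: {t2 - t1:.4f}s")
```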

## Dataset Download

[The Defected Bilayer Graphene Dataset](https://zenodo.org/records/10374206)

[The Formate Decomposition on Cu Dataset](https://archive.materialscloud.org/record/2022.45)

[The Zeolite Dataset](https://doi.org/10.6084/m9.figshare.27800211)

[The OC dataset](https://opencatalystproject.org/)

[The MPtrj dataset](https://matbench-discovery.materialsproject.org/data)

## Pretrained Models

Current pretrained models:
Current pretrained models (due to the architecture changes, the previous pretrained models need to be updated; this will be done as soon as possible):

For materials:
- [AlphaNet-MPtrj-v1](pretrained/MPtrj): A model trained on the MpTrj dataset.
- [AlphaNet-oma-v1](pretrained/OMA): A model trained on the OMAT24 dataset, and finetuned on sALEX+MPtrj.
- [AlphaNet-MATPES-r2scan](pretrained/MATPES): A model trained on the MATPES-r2scan dataset.

For surface adsorption and reactions:
- [AlphaNet-AQCAT25](pretrained/AQCAT25): A model trained on the AQCAT25 dataset.
- [AlphaNet-oma-v1.5](pretrained/OMA): A model trained on the OMAT24 dataset, and finetuned on sALEX+MPtrj.

## Use AlphaNet in LAMMPS

See [mliap_lammps](mliap_lammps.md)

## License

@@ -222,3 +173,6 @@ We thank all contributors and the community for their support. Please open an is






6 changes: 4 additions & 2 deletions alphanet/cli.py
@@ -57,7 +57,8 @@ def display_config_table(main_config, runtime_config):
@click.option("--num_devices", type=int, default=1, help="GPUs per node")
@click.option("--resume", is_flag=True, help="Resume training from checkpoint")
@click.option("--ckpt_path", type=click.Path(), default=None, help="Path to checkpoint file")
def main(config, num_nodes, num_devices, resume, ckpt_path):
@click.option("--finetune", type=click.Path(exists=True), default=None, help="Path to pretrained checkpoint for finetuning (resets optimizer)")
def main(config, num_nodes, num_devices, resume, ckpt_path, finetune):

with open(config, "r") as f:
mconfig = json.load(f)
@@ -67,7 +68,8 @@ def main(config, num_nodes, num_devices, resume, ckpt_path):
"num_nodes": num_nodes,
"num_devices": num_devices,
"resume": resume,
"ckpt_path": ckpt_path
"ckpt_path": ckpt_path,
"finetune_path": finetune
}

display_header()
11 changes: 9 additions & 2 deletions alphanet/config.py
@@ -2,7 +2,7 @@

import subprocess
import json
import torch
#import torch
from typing import Literal, Dict, Optional
from pydantic_settings import BaseSettings

@@ -22,6 +22,7 @@ class TrainConfig(BaseSettings):
batch_size: int = 32
vt_batch_size: int = 32
lr: float = 0.0005
optimizer: str = "radam"
lr_decay_factor: float = 0.5
lr_decay_step_size: int = 150
weight_decay: float = 0
@@ -86,7 +87,13 @@ class AlphaConfig(BaseSettings):
has_norm_after_flag: bool = False
reduce_mode: str = "sum"
zbl: bool = False
device: torch.device = torch.device('cuda') if torch.cuda.is_available() else torch.device("cpu")
zbl_w: Optional[list] = [0.187,0.3769,0.189,0.081,0.003,0.037,0.0546,0.0715]
zbl_b: Optional[list] = [3.20,1.10,0.102,0.958,1.28,1.14,1.69,5]
zbl_gamma: float = 1.001
zbl_alpha: float = 0.6032
zbl_E2: float = 14.399645478425
zbl_A0: float = 0.529177210903
device: str = "cuda"



91 changes: 91 additions & 0 deletions alphanet/create_lammps_model.py
@@ -0,0 +1,91 @@
import argparse
import os
import torch
from pathlib import Path

# Import the AlphaNet model wrapper and config
from alphanet.models.model import AlphaNetWrapper
from alphanet.config import All_Config

# Import the Python-level LAMMPS interface class
try:
from alphanet.infer.lammps_mliap_alphanet import LAMMPS_MLIAP_ALPHANET
except ImportError:
print("Could not import LAMMPS_MLIAP_ALPHANET.")
print("Please ensure 'alphanet/infer/lammps_mliap_alphanet.py' exists.")
exit(1)


def parse_args():
parser = argparse.ArgumentParser(
description="Convert an AlphaNet model to LAMMPS ML-IAP format (Python Pickle)",
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
)
parser.add_argument(
"--config", "-c", required=True, type=str,
help="Path to the model configuration JSON file",
)
parser.add_argument(
"--checkpoint", "-m", required=True, type=str,
help="Path to the trained model checkpoint (.ckpt)",
)
parser.add_argument(
"--output", "-o", required=True, type=str,
help="Output path to save the model (e.g., alphanet_lammps.pt)",
)
parser.add_argument(
"--device", type=str, default="cpu",
help="Device to load the model on ('cpu' or 'cuda')",
)
parser.add_argument(
"--dtype", type=str, default="float64",
choices=["float32", "float64"],
help="Data type for the model",
)
return parser.parse_args()

def main():
args = parse_args()

device = torch.device(args.device)

print(f"1. Loading configuration from {args.config}...")
config_obj = All_Config().from_json(args.config)

config_obj.model.dtype = "64" if args.dtype == "float64" else "32"

print(f"2. Initializing AlphaNetWrapper (precision: {args.dtype}, device: {args.device})...")
model_wrapper = AlphaNetWrapper(config_obj.model)

print(f"3. Loading weights from {args.checkpoint}...")
ckpt = torch.load(args.checkpoint, map_location=device)

if 'state_dict' in ckpt:
state_dict = {k.replace('model.', ''): v for k, v in ckpt['state_dict'].items()}
model_wrapper.model.load_state_dict(state_dict, strict=False)
else:
model_wrapper.load_state_dict(ckpt, strict=False)

if args.dtype == "float64":
model_wrapper.double()
else:
model_wrapper.float()

model_wrapper.to(device).eval()

print("4. Creating LAMMPS ML-IAP Interface Object...")
lammps_interface_object = LAMMPS_MLIAP_ALPHANET(model_wrapper)

if device.type == 'cuda':
lammps_interface_object.model.cuda()

print(f"5. Saving Python object to {args.output}...")
# Using standard torch.save for Python pickle compatibility
torch.save(lammps_interface_object, args.output)

print("\n--- Success ---")
print(f"Created LAMMPS model file: {args.output}")
print("Usage in LAMMPS: pair_style mliap model/python ...")

if __name__ == "__main__":
main()
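A small optional sanity check of the pickled object before pointing LAMMPS at it (a sketch; it assumes `alphanet` is importable in the current environment and that the file was produced by the script above with the default output name):

```python
import torch

# Load the pickled ML-IAP interface object written by create_lammps_model.py.
# weights_only=False is required on newer PyTorch to unpickle a full Python object.
obj = torch.load("alphanet_lammps.pt", map_location="cpu", weights_only=False)

print(type(obj).__name__)                   # expected: LAMMPS_MLIAP_ALPHANET
print(next(obj.model.parameters()).dtype)   # torch.float64 if built with --dtype float64
```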
4 changes: 3 additions & 1 deletion alphanet/infer/calc.py
@@ -29,6 +29,8 @@ def __init__(self, ckpt_path, config, device='cpu', precision='32', **kwargs):
Calculator.__init__(self, **kwargs)

# --- Model Loading ---
if precision == "64":
config.dtype = '64'
if ckpt_path.endswith('ckpt'):
self.model = AlphaNetWrapper(config).to(torch.device(device))
# Load state dict, ignoring mismatches if any
@@ -42,7 +44,7 @@ def __init__(self, ckpt_path, config, device='cpu', precision='32', **kwargs):
self.precision = torch.float32 if precision == "32" else torch.float64

if precision == "64":
self.model.double()
self.model.double()

self.model.eval() # Set model to evaluation mode
self.model.to(self.device)
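For context, a usage sketch of the calculator with `precision='64'`. The class name `AlphaNetCalculator` and the exact config object it expects are assumptions, since only this hunk of `alphanet/infer/calc.py` is shown; the constructor arguments follow the signature visible in the diff:

```python
from ase.build import bulk
from alphanet.config import All_Config
# Class name below is an assumption; only the __init__ signature is visible in this diff.
from alphanet.infer.calc import AlphaNetCalculator

config = All_Config().from_json("example.json")
calc = AlphaNetCalculator(
    ckpt_path="/path/to/your.ckpt",
    config=config.model,      # assumed: the model sub-config, as in create_lammps_model.py
    device="cpu",
    precision="64",           # with "64", config.dtype is set to '64' and the model is cast to double
)

atoms = bulk("Cu")
atoms.calc = calc
print(atoms.get_potential_energy())
```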