Description
Hi, thanks for releasing AlphaNet and the pretrained models!
I’m trying to use AlphaNet-oma-v1 as an ASE calculator for NPT MD on a large system, running on a cluster with NVIDIA B200 (sm_100) GPUs. I ran into some issues:
Environment
- Cluster GPU: NVIDIA B200 (sm_100)
- Python: 3.10
- Conda env: /blue/.../conda_envs/alphanet
- PyTorch:
  - Initially: 2.1.2+cu121
  - Later: upgraded to a newer version following the PyTorch “Get started” guide for CUDA 12 / my GPU (the torch_scatter error is the same either way)
- NumPy: 1.26.4
- ASE: 3.26.0
- AlphaNet: cloned from GitHub, installed with pip install -e .
- JAX: installed with the CUDA 12 plugin as suggested in the README
1. PyTorch backend on B200 (sm_100) + torch_scatter
Using:
from alphanet.infer.calc import AlphaNetCalculator
from alphanet.config import All_Config

config = All_Config().from_json(
    "/blue/.../Software/AlphaNet/pretrained/OMA/oma.json"
)
calc = AlphaNetCalculator(
    ckpt_path="/blue/.../Software/AlphaNet/pretrained/OMA/alex_0410.ckpt",
    device="cuda",
    precision="32",
    config=config,
)
I get a GPU architecture warning:
NVIDIA B200 with CUDA capability sm_100 is not compatible with the current PyTorch installation.
The current PyTorch install supports CUDA capabilities sm_50 sm_60 sm_70 sm_75 sm_80 sm_86 sm_90.
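For reference, what the installed wheel supports versus what the GPU reports can be checked directly (a minimal sketch using standard PyTorch attributes; the sm_100 value in the last comment is my reading of the warning above):

import torch

# Compare the CUDA architectures compiled into the installed wheel with the
# device's compute capability.
print(torch.__version__, torch.version.cuda)
print(torch.cuda.get_arch_list())           # e.g. ['sm_50', ..., 'sm_90']
print(torch.cuda.get_device_capability(0))  # (10, 0) for sm_100 / B200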
and a torch_scatter ABI error on import:
/blue/.../conda_envs/alphanet/lib/python3.10/site-packages/torch_geometric/__init__.py:4: UserWarning: An issue occurred while importing 'torch-scatter'. Disabling its usage. Stacktrace: Could not load this library: /blue/.../torch_scatter/_version_cuda.so
OSError: /blue/.../torch_scatter/_version_cuda.so: undefined symbol: _ZN3c1017RegisterOperatorsD1Ev
This happens already when importing AlphaNet (via from alphanet.models.alphanet import AlphaNet), before any MD run.
I later upgraded PyTorch to a newer version (following the official instructions for CUDA 12 and my GPU), but the torch_scatter error remains the same.
So at the moment I cannot reliably use the PyTorch backend on this machine (with device="cuda").
2. JAX/Haiku backend with OMA pretrained model
Because of the PyTorch issues, I tried the JAX backend:
from alphanet.infer.new_haiku import AlphaNetCalculator
from alphanet.config import All_Config
from ase.io import read

CKPT_PATH = "/blue/.../Software/AlphaNet/pretrained/OMA/haiku/haiku_params.pkl"
CONFIG_PATH = "/blue/.../Software/AlphaNet/pretrained/OMA/oma.json"

config = All_Config().from_json(CONFIG_PATH)

atoms = read(
    "170LiQC.data",
    format="lammps-data",
    Z_of_type={1: 1, 2: 3, 3: 6, 4: 7, 5: 8, 6: 9, 7: 16},
)
atoms.set_pbc([True, True, True])

calculator = AlphaNetCalculator(
    ckpt_path=CKPT_PATH,
    device="cuda",
    precision="32",
    config=config,
)
atoms.calc = calculator
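For context, the calculator is then driven by a standard ASE NPT loop. A minimal sketch of what I run (the timestep, temperature, and barostat settings here are placeholders, not my production values):

from ase import units
from ase.md.npt import NPT

# Placeholder NPT settings; the real run uses a larger system and different
# coupling constants.
dyn = NPT(
    atoms,
    timestep=1.0 * units.fs,
    temperature_K=300,
    externalstress=1.01325 * units.bar,        # ~1 atm
    ttime=25 * units.fs,                       # thermostat time constant
    pfactor=(75 * units.fs) ** 2 * units.GPa,  # barostat "mass"
)
dyn.run(10)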
On the first call I see JAX/XLA autotuning warnings (which I assume are benign), but then it crashes with:
Missing leaf parameters
alpha_net_hiku/a
alpha_net_hiku/b
alpha_net_hiku/kernel1
...
ValueError: Total of 158 missing leaves.
File ".../alphanet/infer/new_haiku.py", line 46, in __init__
check_params(override_leaves, base_leaves)
I also tried the common unwrapping pattern in new_haiku.py:
with open(ckpt_path, "rb") as f:
    raw_params = pickle.load(f)
if isinstance(raw_params, dict) and "params" in raw_params:
    raw_params = raw_params["params"]
but I still get Total of 158 missing leaves, so it looks like the parameter tree in haiku_params.pkl does not match the current Haiku model code.
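To pin down the mismatch, this is roughly how I compare the leaf names stored in the pickle against the names in the error message (a sketch; it assumes haiku_params.pkl is a plain pickled Haiku mapping of module name -> parameter name -> array):

import pickle

# Dump the module/parameter names stored in the checkpoint so they can be
# compared against the names the current Haiku model expects
# (alpha_net_hiku/a, alpha_net_hiku/b, ...).
with open(CKPT_PATH, "rb") as f:
    params = pickle.load(f)

for module_name, leaves in params.items():
    for leaf_name, value in leaves.items():
        print(module_name, leaf_name, getattr(value, "shape", None))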
If I instead point the Haiku loader to the PyTorch Lightning checkpoint:
CKPT_PATH = "/blue/.../pretrained/OMA/alex_0410.ckpt"
I get:
_pickle.UnpicklingError: A load persistent id instruction was encountered,
but no persistent_load function was specified.
which is expected since that file is not a plain Haiku params pickle.
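If converting the Lightning checkpoint to Haiku parameters is the intended route, I assume the starting point would be something like the following (untested sketch; "state_dict" is the standard Lightning key, and the mapping from these names onto the Haiku parameter tree is exactly the part I'm missing):

import torch

# Untested sketch: list the weight names in the Lightning checkpoint so they
# could be mapped onto the Haiku parameter tree.
ckpt = torch.load(CKPT_PATH, map_location="cpu", weights_only=False)
for key, tensor in ckpt["state_dict"].items():
    print(key, tuple(tensor.shape))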
Questions
For the JAX/Haiku backend:
- What is the correct way to use the OMA pretrained model with alphanet.infer.new_haiku.AlphaNetCalculator?
- Is pretrained/OMA/haiku/haiku_params.pkl expected to match pretrained/OMA/oma.json and the current new_haiku.py?
- Could you provide a minimal working example that uses this Haiku checkpoint and config to compute the energy of a simple ASE system (e.g. bulk Cu), so I can compare with my setup? (See the sketch below for the kind of example I mean.)
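Here is an untested sketch of the kind of example I mean; I am assuming the calculator arguments are the same as in my snippet above and that the paths point into the pretrained/OMA folder:

from ase.build import bulk
from alphanet.infer.new_haiku import AlphaNetCalculator
from alphanet.config import All_Config

# Untested sketch of the minimal example I'm asking for: the same calculator
# arguments as above, applied to a simple bulk Cu cell.
config = All_Config().from_json("pretrained/OMA/oma.json")
atoms = bulk("Cu", "fcc", a=3.6, cubic=True)
atoms.calc = AlphaNetCalculator(
    ckpt_path="pretrained/OMA/haiku/haiku_params.pkl",
    device="cuda",
    precision="32",
    config=config,
)
print(atoms.get_potential_energy())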
For the PyTorch backend on sm_100 GPUs:
- Do you have a recommended PyTorch / torch_scatter / torch_geometric version that is known to work with AlphaNet on B200 (sm_100)?
- Or should we currently prefer the JAX backend on such hardware?
Documentation:
In the README, the comment # ./pretrained/OMA/haiku/haiku_params.pkl haiku ckpt is a bit ambiguous. It would be very helpful to explicitly state which checkpoint files are for PyTorch, which are for Haiku/JAX, and which calculator (calc vs new_haiku) should be used with each.
Thanks a lot in advance – I’m happy to test any suggested environment, branch, or checkpoint on the B200 node and report back performance and correctness.