
AlphaNet-oma-v1 on NVIDIA B200: PyTorch GPU / torch_scatter issue and JAX/Haiku checkpoint “missing leaves” #7

@Rider99

Description


Hi, thanks for releasing AlphaNet and the pretrained models!
I’m trying to use AlphaNet-oma-v1 as an ASE calculator for NPT MD on a large system, running on a cluster with NVIDIA B200 (sm_100) GPUs. I ran into some issues:

Environment

  • Cluster GPU: NVIDIA B200 (sm_100)
  • Python: 3.10
  • Conda env: /blue/.../conda_envs/alphanet
  • PyTorch:
      • Initially: 2.1.2+cu121
      • Later: upgraded to a newer version following the PyTorch “Get started” guide for CUDA 12 / my GPU (same behavior for torch_scatter)
  • NumPy: 1.26.4
  • ASE: 3.26.0
  • AlphaNet: cloned from GitHub, installed with pip install -e .
  • JAX: installed with the CUDA 12 plugin as suggested in the README

1. PyTorch backend on B200 (sm_100) + torch_scatter

Using:

from alphanet.infer.calc import AlphaNetCalculator
from alphanet.config import All_Config

config = All_Config().from_json(
    "/blue/.../Software/AlphaNet/pretrained/OMA/oma.json"
)

calc = AlphaNetCalculator(
    ckpt_path="/blue/.../Software/AlphaNet/pretrained/OMA/alex_0410.ckpt",
    device="cuda",
    precision="32",
    config=config,
)

I get a GPU architecture warning:

NVIDIA B200 with CUDA capability sm_100 is not compatible with the current PyTorch installation.
The current PyTorch install supports CUDA capabilities sm_50 sm_60 sm_70 sm_75 sm_80 sm_86 sm_90.
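
For reference, this is the kind of check I can run on the node to report the exact situation (plain torch.cuda introspection, nothing AlphaNet-specific; the expected outputs in the comments are my reading of the warning above):

import torch

print(torch.__version__, torch.version.cuda)
# sm_100 corresponds to compute capability (10, 0)
print(torch.cuda.get_device_capability(0))
# per the warning above, sm_100 is absent from this list
print(torch.cuda.get_arch_list())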

I also get a torch_scatter ABI error on import:

/blue/.../conda_envs/alphanet/lib/python3.10/site-packages/torch_geometric/__init__.py:4: UserWarning: An issue occurred while importing 'torch-scatter'. Disabling its usage. Stacktrace: Could not load this library: /blue/.../torch_scatter/_version_cuda.so
OSError: /blue/.../torch_scatter/_version_cuda.so: undefined symbol: _ZN3c1017RegisterOperatorsD1Ev

This happens as soon as AlphaNet is imported (via from alphanet.models.alphanet import AlphaNet), before any MD run.
I later upgraded PyTorch to a newer version (following the official instructions for CUDA 12 and my GPU), but the torch_scatter error remains the same.

So at the moment I cannot reliably use the PyTorch backend on this machine (with device="cuda").
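
A minimal reproduction of the import-time failure, for reference (no calculator construction or MD needed):

# Importing the model module alone is enough to trigger both the sm_100
# warning and the torch_scatter "undefined symbol" message quoted above.
from alphanet.models.alphanet import AlphaNet  # noqa: F401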

2. JAX/Haiku backend with OMA pretrained model

Because of the PyTorch issues, I tried the JAX backend:

from alphanet.infer.new_haiku import AlphaNetCalculator
from alphanet.config import All_Config
from ase.io import read

CKPT_PATH = "/blue/.../Software/AlphaNet/pretrained/OMA/haiku/haiku_params.pkl"
CONFIG_PATH = "/blue/.../Software/AlphaNet/pretrained/OMA/oma.json"

config = All_Config().from_json(CONFIG_PATH)

atoms = read("170LiQC.data", format="lammps-data",
             Z_of_type={1: 1, 2: 3, 3: 6, 4: 7, 5: 8, 6: 9, 7: 16})
atoms.set_pbc([True, True, True])

calculator = AlphaNetCalculator(
    ckpt_path=CKPT_PATH,
    device="cuda",
    precision="32",
    config=config,
)
atoms.calc = calculator

On the first call I see JAX/XLA autotuning warnings (which I assume are benign), but then it crashes with:

Missing leaf parameters
   alpha_net_hiku/a
   alpha_net_hiku/b
   alpha_net_hiku/kernel1
   ...
ValueError: Total of 158 missing leaves.

  File ".../alphanet/infer/new_haiku.py", line 46, in __init__
    check_params(override_leaves, base_leaves)

I also tried manually unwrapping the checkpoint, following the pattern used in new_haiku.py:

import pickle

with open(ckpt_path, "rb") as f:
    raw_params = pickle.load(f)

# unwrap a possible {"params": ...} wrapper around the Haiku tree
if isinstance(raw_params, dict) and "params" in raw_params:
    raw_params = raw_params["params"]

but I still get Total of 158 missing leaves, so it looks like the parameter tree in haiku_params.pkl does not match the current Haiku model code.
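
To compare the stored parameter tree against what the model expects (e.g. the alpha_net_hiku/... names in the error above), I also dumped the module names from the pickle; a rough sketch, assuming haiku_params.pkl is a plain nested dict of Haiku params:

import pickle
import jax

with open(CKPT_PATH, "rb") as f:
    params = pickle.load(f)
if isinstance(params, dict) and "params" in params:
    params = params["params"]

# Haiku params are normally a mapping of module name -> {param name: array}.
print("modules:", sorted(params.keys())[:20])
print("leaves:", len(jax.tree_util.tree_leaves(params)))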

If I instead point the Haiku loader to the PyTorch Lightning checkpoint:

CKPT_PATH = "/blue/.../pretrained/OMA/alex_0410.ckpt"

I get:

_pickle.UnpicklingError: A load persistent id instruction was encountered,
but no persistent_load function was specified.

which is expected since that file is not a plain Haiku params pickle.
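
For completeness, a quick sanity check (assuming the standard Lightning layout with a "state_dict" entry) that this .ckpt really is a PyTorch/Lightning checkpoint rather than a Haiku pickle:

import torch

# torch.load handles the persistent-id pickle protocol that plain pickle.load
# cannot; on recent PyTorch versions weights_only=False may be required.
ckpt = torch.load("/blue/.../pretrained/OMA/alex_0410.ckpt", map_location="cpu")
print(type(ckpt))
if isinstance(ckpt, dict):
    print(list(ckpt.keys()))  # a Lightning checkpoint typically has "state_dict" etc.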

Questions

For the JAX/Haiku backend:

  • What is the correct way to use the OMA pretrained model with alphanet.infer.new_haiku.AlphaNetCalculator?
  • Is pretrained/OMA/haiku/haiku_params.pkl expected to match pretrained/OMA/oma.json and the current new_haiku.py?
  • Could you provide a minimal working example that uses this Haiku checkpoint and config to compute the energy of a simple ASE system (e.g. bulk Cu), so I can compare with my setup? (Something along the lines of the sketch below.)
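
For concreteness, the kind of minimal example I'm hoping for would look roughly like this (the constructor arguments are copied from my snippets above, so please correct anything that differs from the intended API):

from ase.build import bulk

from alphanet.config import All_Config
from alphanet.infer.new_haiku import AlphaNetCalculator

config = All_Config().from_json("pretrained/OMA/oma.json")
calc = AlphaNetCalculator(
    ckpt_path="pretrained/OMA/haiku/haiku_params.pkl",
    device="cuda",
    precision="32",
    config=config,
)

atoms = bulk("Cu", "fcc", a=3.6)  # simple bulk Cu reference system
atoms.calc = calc
print(atoms.get_potential_energy())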

For the PyTorch backend on sm_100 GPUs:

  • Do you have a recommended PyTorch / torch_scatter / torch_geometric version that is known to work with AlphaNet on B200 (sm_100)?
  • Or should we currently prefer the JAX backend on such hardware?

Documentation:

In the README, the comment # ./pretrained/OMA/haiku/haiku_params.pkl haiku ckpt is a bit ambiguous. It would be very helpful to explicitly state which checkpoint files are for PyTorch, which are for Haiku/JAX, and which calculator (calc vs new_haiku) should be used with each.
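
For example, even a short mapping like this in the README (my current understanding from the paths above; please correct it if wrong) would remove the ambiguity:

# PyTorch calculator (alphanet.infer.calc.AlphaNetCalculator):
#     ./pretrained/OMA/alex_0410.ckpt           + ./pretrained/OMA/oma.json
# JAX/Haiku calculator (alphanet.infer.new_haiku.AlphaNetCalculator):
#     ./pretrained/OMA/haiku/haiku_params.pkl   + ./pretrained/OMA/oma.json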

Thanks a lot in advance – I’m happy to test any suggested environment, branch, or checkpoint on the B200 node and report back performance and correctness.
