Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion .devcontainer/devcontainer.json
Original file line number Diff line number Diff line change
Expand Up @@ -39,4 +39,3 @@
}
}
}

1 change: 0 additions & 1 deletion .github/dependabot.yml
Original file line number Diff line number Diff line change
Expand Up @@ -22,4 +22,3 @@ updates:
labels:
- "dependencies"
- "ci"

1 change: 0 additions & 1 deletion .github/workflows/gpu-tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -36,4 +36,3 @@ jobs:

- name: Run GPU tests
run: uv run pytest tests -m gpu -v

18 changes: 10 additions & 8 deletions aimnet/calculators/aimnet2ase.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,10 +21,10 @@ class AIMNet2ASE(Calculator):
"dipole_moment",
]

def __init__(self, base_calc: AIMNet2Calculator | str = "aimnet2", charge=0, mult=1):
def __init__(self, base_calc: AIMNet2Calculator | str = "aimnet2", charge=0, mult=1, compile_mode: bool = False):
super().__init__()
if isinstance(base_calc, str):
base_calc = AIMNet2Calculator(base_calc)
base_calc = AIMNet2Calculator(base_calc, compile_mode=compile_mode)
self.base_calc = base_calc
self.reset()
self.charge = charge
Expand Down Expand Up @@ -86,13 +86,15 @@ def calculate(self, atoms=None, properties=None, system_changes=all_changes):
"mult": self._t_mult,
}

# In compile mode, calculator handles batching internally
_unsqueezed = False
if cell is not None:
_in["cell"] = cell
else:
for k, v in _in.items():
_in[k] = v.unsqueeze(0)
_unsqueezed = True
if not self.base_calc._compile_mode:
if cell is not None:
_in["cell"] = cell
else:
for k, v in _in.items():
_in[k] = v.unsqueeze(0)
_unsqueezed = True

results = self.base_calc(_in, forces="forces" in properties, stress="stress" in properties)

Expand Down
91 changes: 87 additions & 4 deletions aimnet/calculators/calculator.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,9 @@
import torch
from torch import Tensor, nn

from .model_registry import get_model_path
from aimnet.config import build_module

from .model_registry import get_model_definition_path, get_model_path
from .nbmat import TooManyNeighborsError, calc_nbmat


Expand All @@ -28,11 +30,27 @@ class AIMNet2Calculator:
keys_out: ClassVar[list[str]] = ["energy", "charges", "forces", "hessian", "stress"]
atom_feature_keys: ClassVar[list[str]] = ["coord", "numbers", "charges", "forces"]

def __init__(self, model: str | nn.Module = "aimnet2", nb_threshold: int = 320):
def __init__(self, model: str | nn.Module = "aimnet2", nb_threshold: int = 320, compile_mode: bool = False):
self.device = "cuda" if torch.cuda.is_available() else "cpu"
self._compile_mode = compile_mode
self._model_name: str | None = None

if compile_mode:
if not torch.cuda.is_available():
raise ValueError("compile_mode requires CUDA")
if not isinstance(model, str):
raise ValueError("compile_mode requires model name (str), not nn.Module")

if isinstance(model, str):
self._model_name = model
p = get_model_path(model)
self.model = torch.jit.load(p, map_location=self.device)
jit_model = torch.jit.load(p, map_location=self.device)

if compile_mode:
# Build native PyTorch model for torch.compile
self.model = self._build_compiled_model(model, jit_model)
else:
self.model = jit_model
elif isinstance(model, nn.Module):
self.model = model.to(self.device)
else:
Expand All @@ -58,6 +76,44 @@ def __init__(self, model: str | nn.Module = "aimnet2", nb_threshold: int = 320):
else:
self._coulomb_method = None

def _build_compiled_model(self, model_name: str, jit_model: nn.Module) -> nn.Module:
    """Construct a torch.compile-ready native model mirroring a JIT model.

    A TorchScript archive cannot be fed to ``torch.compile``, so the
    architecture is rebuilt from its YAML definition, the trained weights
    are transplanted from the JIT model, compile mode is switched on, and
    the result is wrapped with ``torch.compile``.

    Args:
        model_name: Registry name of the model whose YAML definition to load.
        jit_model: Already-loaded JIT model supplying the trained weights.

    Returns:
        The ``torch.compile``-wrapped native model.
    """
    definition_path = get_model_definition_path(model_name)

    # Rebuild the architecture as plain (un-jitted) PyTorch modules.
    rebuilt: nn.Module = build_module(definition_path)  # type: ignore[assignment]
    rebuilt = rebuilt.to(self.device)

    # Transplant trained parameters from the JIT archive.
    rebuilt.load_state_dict(jit_model.state_dict())

    # Switch to compile-time-constant control flow (fixed dense nb_mode=0),
    # when the model supports it.
    if hasattr(rebuilt, "enable_compile_mode"):
        rebuilt.enable_compile_mode(nb_mode=0)

    rebuilt.eval()
    # Compile with CUDA graph capture via the inductor/triton backend.
    return torch.compile(
        rebuilt,
        fullgraph=True,
        options={"triton.cudagraphs": True},
    )

def __call__(self, *args, **kwargs):
    # Convenience alias: calling the calculator instance delegates to eval().
    return self.eval(*args, **kwargs)

Expand All @@ -84,14 +140,32 @@ def eval(self, data: dict[str, Any], forces=False, stress=False, hessian=False)
if hessian and "mol_idx" in data and data["mol_idx"][-1] > 0:
raise NotImplementedError("Hessian calculation is not supported for multiple molecules")
data = self.set_grad_tensors(data, forces=forces, stress=stress, hessian=hessian)
with torch.jit.optimized_execution(False): # type: ignore
if self._compile_mode:
# For compiled model, run directly
data = self.model(data)
else:
with torch.jit.optimized_execution(False): # type: ignore
data = self.model(data)
data = self.get_derivatives(data, forces=forces, stress=stress, hessian=hessian)
data = self.process_output(data)
return data

def prepare_input(self, data: dict[str, Any]) -> dict[str, Tensor]:
data = self.to_input_tensors(data)

if self._compile_mode:
# Compile mode requires batch dim and nb_mode=0 (dense)
if data.get("cell") is not None:
raise NotImplementedError("PBC is not supported in compile mode")
# Ensure batch dimension for compile mode
if data["coord"].ndim == 2:
for k in ("coord", "numbers"):
data[k] = data[k].unsqueeze(0)
data["charge"] = data["charge"].reshape(1)
if "mult" in data:
data["mult"] = data["mult"].reshape(1)
return data

data = self.mol_flatten(data)
if data.get("cell") is not None:
if data["mol_idx"][-1] > 0:
Expand All @@ -105,6 +179,15 @@ def prepare_input(self, data: dict[str, Any]) -> dict[str, Tensor]:
return data

def process_output(self, data: dict[str, Tensor]) -> dict[str, Tensor]:
if self._compile_mode:
# In compile mode, remove batch dim if we added it
if data["coord"].ndim == 3 and data["coord"].shape[0] == 1:
for k in self.atom_feature_keys:
if k in data:
data[k] = data[k].squeeze(0)
data = self.keep_only(data)
return data

if data["coord"].ndim == 2:
data = self.unpad_output(data)
data = self.mol_unflatten(data)
Expand Down
36 changes: 36 additions & 0 deletions aimnet/calculators/model_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,42 @@ def get_model_path(s: str) -> str:
return s


def get_model_definition_path(model_name: str) -> str:
    """Get the YAML definition file path for a model name.

    Maps a model name (or registry alias) to the architecture definition
    YAML file, which is needed for torch.compile (requires an un-jitted
    model rebuilt from its definition).

    Args:
        model_name: Model name or alias from the registry.

    Returns:
        Path to the YAML model definition file.
    """
    model_registry = load_model_registry()

    # Resolve alias first so the naming heuristic below sees the canonical name.
    model_name = model_registry.get("aliases", {}).get(model_name, model_name)

    models_dir = os.path.join(os.path.dirname(__file__), "..", "models")

    # NSE models ship without D3 dispersion; every other family (standard
    # aimnet2, d3, pd variants) uses the DFTD3 definition.  The original
    # code spelled these out as three branches, two of which returned the
    # same file — collapsed here into a single condition.
    yaml_file = "aimnet2.yaml" if "nse" in model_name.lower() else "aimnet2_dftd3_wb97m.yaml"

    return os.path.join(models_dir, yaml_file)


@click.command(short_help="Clear assets directory.")
def clear_assets():
from glob import glob
Expand Down
14 changes: 9 additions & 5 deletions aimnet/models/aimnet2.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,20 +104,23 @@ def _postprocess_spin_polarized_charge(self, data: dict[str, Tensor]) -> dict[st
return data

def _prepare_in_a(self, data: dict[str, Tensor]) -> Tensor:
    """Assemble the MLP input features from the atomic 'a' channel.

    Gathers central/neighbor views of ``data["a"]``, convolves the neighbor
    view with the radial/vector expansions, and concatenates the result
    with the central-atom features.
    """
    # In compile mode the neighbor mode is a compile-time constant; -1 = dynamic.
    nb_mode = self._compile_nb_mode if self._compile_mode else -1
    feat_i, feat_j = nbops.get_ij(data["a"], data, compile_nb_mode=nb_mode)
    conv_out = self.conv_a(feat_j, data["gs"], data["gv"])
    if self.d2features:
        # 2-D feature layout: merge the trailing feature axes before concat.
        feat_i = feat_i.flatten(-2, -1)
    return torch.cat([feat_i.squeeze(-2), conv_out], dim=-1)

def _prepare_in_q(self, data: dict[str, Tensor]) -> Tensor:
    """Assemble the MLP input features from the partial-charge channel.

    Mirrors ``_prepare_in_a`` but operates on ``data["charges"]`` and the
    charge convolution.
    """
    # In compile mode the neighbor mode is a compile-time constant; -1 = dynamic.
    nb_mode = self._compile_nb_mode if self._compile_mode else -1
    chg_i, chg_j = nbops.get_ij(data["charges"], data, compile_nb_mode=nb_mode)
    conv_out = self.conv_q(chg_j, data["gs"], data["gv"])
    return torch.cat([chg_i.squeeze(-2), conv_out], dim=-1)

def _update_q(self, data: dict[str, Tensor], x: Tensor, delta_q: bool = True) -> dict[str, Tensor]:
compile_nb = self._compile_nb_mode if self._compile_mode else -1
_q, _f, delta_a = x.split(
[
self.num_charge_channels,
Expand All @@ -127,11 +130,11 @@ def _update_q(self, data: dict[str, Tensor], x: Tensor, delta_q: bool = True) ->
dim=-1,
)
# for loss
data["_delta_Q"] = data["charge"] - nbops.mol_sum(_q, data)
data["_delta_Q"] = data["charge"] - nbops.mol_sum(_q, data, compile_nb_mode=compile_nb)
q = data["charges"] + _q if delta_q else _q
data["charges_pre"] = q if self.num_charge_channels == 2 else q.squeeze(-1)
f = _f.pow(2)
q = ops.nse(data["charge"], q, f, data, epsilon=1.0e-6)
q = ops.nse(data["charge"], q, f, data, epsilon=1.0e-6, compile_nb_mode=compile_nb)
data["charges"] = q
data["a"] = data["a"] + delta_a.view_as(data["a"])
return data
Expand Down Expand Up @@ -164,7 +167,8 @@ def forward(self, data: dict[str, Tensor]) -> dict[str, Tensor]:
_in = torch.cat([self._prepare_in_a(data), self._prepare_in_q(data)], dim=-1)

_out = mlp(_in)
if data["_input_padded"].item():
# In compile mode, skip the .item() check - padding is handled at setup
if not self._compile_mode and data["_input_padded"].item():
_out = nbops.mask_i_(_out, data, mask_value=0.0)

if ipass == 0:
Expand Down
38 changes: 36 additions & 2 deletions aimnet/models/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,33 @@ class AIMNet2Base(nn.Module):

def __init__(self):
    """Initialize base model state; compile mode is disabled by default."""
    super().__init__()
    # Compile mode attributes (toggled by enable_compile_mode):
    # _compile_mode: when True, the model avoids data-dependent control flow.
    self._compile_mode: bool = False
    self._compile_nb_mode: int = -1  # -1 = dynamic, 0/1/2 = fixed

def enable_compile_mode(self, nb_mode: int = 0) -> None:
    """Switch the model into compile mode with a fixed neighbor mode.

    Replaces data-dependent control flow (``.item()`` calls) with
    compile-time constants so the model stays compatible with CUDA graph
    capture.

    Args:
        nb_mode: Fixed neighbor mode (0=dense, 1=sparse, 2=batched).
            Only 0 is implemented at present.

    Raises:
        ValueError: If ``nb_mode`` is not one of 0, 1, 2.
        NotImplementedError: If ``nb_mode`` is 1 or 2.
    """
    # Validation order matters: an out-of-range value raises ValueError,
    # a known-but-unsupported value raises NotImplementedError.
    if nb_mode not in {0, 1, 2}:
        raise ValueError(f"nb_mode must be 0, 1, or 2, got {nb_mode}")
    if nb_mode != 0:
        raise NotImplementedError(f"Compile mode only supports nb_mode=0 currently, got {nb_mode}")

    self._compile_mode = True
    self._compile_nb_mode = nb_mode

    # Flip the flag on every submodule that opts in by declaring the attribute.
    for sub in self.modules():
        if hasattr(sub, "_compile_mode"):
            sub._compile_mode = True
            sub._compile_nb_mode = nb_mode

def _prepare_dtype(self, data: dict[str, Tensor]) -> dict[str, Tensor]:
for k, d in zip(self._required_keys, self._required_keys_dtype, strict=False):
Expand All @@ -42,8 +69,15 @@ def _prepare_dtype(self, data: dict[str, Tensor]) -> dict[str, Tensor]:
def prepare_input(self, data: dict[str, Tensor]) -> dict[str, Tensor]:
"""Some sommon operations"""
data = self._prepare_dtype(data)
data = nbops.set_nb_mode(data)
data = nbops.calc_masks(data)

if self._compile_mode:
# In compile mode, use fixed nb_mode - no data-dependent branching
data["_nb_mode"] = torch.tensor(self._compile_nb_mode)
data = nbops.calc_masks_fixed_nb_mode(data, self._compile_nb_mode)
else:
# Dynamic mode - detect nb_mode from data
data = nbops.set_nb_mode(data)
data = nbops.calc_masks(data)

assert data["charge"].ndim == 1, "Charge should be 1D tensor."
if "mult" in data:
Expand Down
17 changes: 14 additions & 3 deletions aimnet/modules/aev.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,9 @@ def __init__(
shifts_v: list[float] | None = None,
):
super().__init__()
# Compile mode attributes
self._compile_mode: bool = False
self._compile_nb_mode: int = -1

self._init_basis(rc_s, eta_s, nshifts_s, shifts_s, rmin, mod="_s")
if rc_v is not None:
Expand Down Expand Up @@ -79,7 +82,8 @@ def _init_basis(self, rc, eta, nshifts, shifts, rmin, mod="_s"):

def forward(self, data: dict[str, Tensor]) -> dict[str, Tensor]:
# shapes (..., m) and (..., m, 3)
d_ij, r_ij = ops.calc_distances(data)
compile_nb = self._compile_nb_mode if self._compile_mode else -1
d_ij, r_ij = ops.calc_distances(data, compile_nb_mode=compile_nb)
data["d_ij"] = d_ij
# shapes (..., nshifts, m) and (..., nshifts, 3, m)
u_ij, gs, gv = self._calc_aev(r_ij, d_ij, data)
Expand All @@ -88,14 +92,21 @@ def forward(self, data: dict[str, Tensor]) -> dict[str, Tensor]:
return data

def _calc_aev(self, r_ij: Tensor, d_ij: Tensor, data: dict[str, Tensor]) -> tuple[Tensor, Tensor, Tensor]:
fc_ij = ops.cosine_cutoff(d_ij, self.rc_s) # (..., m)
# In compile mode, use tensor version of cosine_cutoff for CUDA graph compatibility
if self._compile_mode:
fc_ij = ops.cosine_cutoff_tensor(d_ij, self.rc_s)
else:
fc_ij = ops.cosine_cutoff(d_ij, self.rc_s.item())
fc_ij = nbops.mask_ij_(fc_ij, data, 0.0)
gs = ops.exp_expand(d_ij, self.shifts_s, self.eta_s) * fc_ij.unsqueeze(
-1
) # (..., m, nshifts) * (..., m, 1) -> (..., m, shitfs)
u_ij = r_ij / d_ij.unsqueeze(-1) # (..., m, 3) / (..., m, 1) -> (..., m, 3)
if self._dual_basis:
fc_ij = ops.cosine_cutoff(d_ij, self.rc_v)
if self._compile_mode:
fc_ij = ops.cosine_cutoff_tensor(d_ij, self.rc_v)
else:
fc_ij = ops.cosine_cutoff(d_ij, self.rc_v.item())
gsv = ops.exp_expand(d_ij, self.shifts_v, self.eta_v) * fc_ij.unsqueeze(-1)
gv = gsv.unsqueeze(-2) * u_ij.unsqueeze(-1)
else:
Expand Down
Loading
Loading