Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions aimnet/calculators/aimnet2ase.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,10 @@ class AIMNet2ASE(Calculator):

implemented_properties: ClassVar[list[str]] = ["energy", "forces", "free_energy", "charges", "stress", "dipole_moment"]

def __init__(self, base_calc: AIMNet2Calculator | str = "aimnet2", charge=0, mult=1):
def __init__(self, base_calc: AIMNet2Calculator | str = "aimnet2", charge=0, mult=1, compile_cuda_graphs: bool=False):
super().__init__()
if isinstance(base_calc, str):
base_calc = AIMNet2Calculator(base_calc)
base_calc = AIMNet2Calculator(base_calc, compile_cuda_graphs=compile_cuda_graphs)
self.base_calc = base_calc
self.reset()
self.charge = charge
Expand Down
52 changes: 51 additions & 1 deletion aimnet/calculators/calculator.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import warnings
import os
from typing import Any, ClassVar, Dict, Literal

import torch
Expand All @@ -7,6 +8,19 @@
from .model_registry import get_model_path
from .nbmat import TooManyNeighborsError, calc_nbmat

from aimnet.config import build_module



def build_model(model_def):
    """Build an AIMNet2 model from a YAML definition file.

    Args:
        model_def: Path to the YAML file describing the model architecture.

    Returns:
        The instantiated ``nn.Module`` produced by ``build_module``.

    Raises:
        FileNotFoundError: If ``model_def`` does not exist.
        TypeError: If the built object is not an ``nn.Module``.
    """
    # Use explicit exceptions instead of asserts: asserts are removed when
    # Python runs with -O, which would silently skip this validation.
    if not os.path.exists(model_def):
        raise FileNotFoundError(f"Model definition file not found: {model_def}.")
    model = build_module(model_def)
    if not isinstance(model, nn.Module):
        raise TypeError("The built model is not an instance of nn.Module.")
    return model

class AIMNet2Calculator:
"""Genegic AIMNet2 calculator
Expand All @@ -28,7 +42,7 @@ class AIMNet2Calculator:
keys_out: ClassVar[list[str]] = ["energy", "charges", "forces", "hessian", "stress"]
atom_feature_keys: ClassVar[list[str]] = ["coord", "numbers", "charges", "forces"]

def __init__(self, model: str | nn.Module = "aimnet2", nb_threshold: int = 320):
def __init__(self, model: str | nn.Module = "aimnet2", nb_threshold: int = 320, compile_cuda_graphs: bool=False):
self.device = "cuda" if torch.cuda.is_available() else "cpu"
if isinstance(model, str):
p = get_model_path(model)
Expand All @@ -37,7 +51,24 @@ def __init__(self, model: str | nn.Module = "aimnet2", nb_threshold: int = 320):
self.model = model.to(self.device)
else:
raise TypeError("Invalid model type/name.")

if compile_cuda_graphs:
self.compile_cuda_graphs = True
# to torch compile we need to un-jit

# build the model using the yaml
aimnet2_d3_def = os.path.join(os.path.dirname(__file__), "..", "models", "aimnet2_dftd3_wb97m.yaml")
python_model = build_model(aimnet2_d3_def)
# copy weights from the torchscript model loaded above
python_model.load_state_dict(self.model.state_dict(), strict=False)
for p in python_model.parameters():
p.requires_grad_(False)
python_model = python_model.eval()
else:
self.compile_cuda_graphs = False


# set up things as usual from the torchscript one
self.cutoff = self.model.cutoff
self.lr = hasattr(self.model, "cutoff_lr")
self.cutoff_lr = getattr(self.model, "cutoff_lr", float("inf")) if self.lr else None
Expand All @@ -57,6 +88,13 @@ def __init__(self, model: str | nn.Module = "aimnet2", nb_threshold: int = 320):
self._coulomb_method = coul_methods.pop()
else:
self._coulomb_method = None

# now we set up and replace the model
if self.compile_cuda_graphs:
python_model.setup_for_compile_cudagraphs()
self.model = python_model.to(self.device)
# now we can compile with cuda-graphs enabled
self.model = torch.compile(self.model, fullgraph=True, options={'triton.cudagraphs':True})

def __call__(self, *args, **kwargs):
return self.eval(*args, **kwargs)
Expand Down Expand Up @@ -84,22 +122,33 @@ def eval(self, data: Dict[str, Any], forces=False, stress=False, hessian=False)
if hessian and "mol_idx" in data and data["mol_idx"][-1] > 0:
raise NotImplementedError("Hessian calculation is not supported for multiple molecules")
data = self.set_grad_tensors(data, forces=forces, stress=stress, hessian=hessian)

with torch.jit.optimized_execution(False): # type: ignore
data = self.model(data)

data = self.get_derivatives(data, forces=forces, stress=stress, hessian=hessian)
data = self.process_output(data)
return data

def prepare_input(self, data: Dict[str, Any]) -> Dict[str, Tensor]:
    """Convert raw input into model-ready tensors.

    Converts values to tensors, flattens multi-molecule input, and (for
    flat, non-batched coordinates) builds the neighbor matrix and pads
    the input.

    NOTE(review): both the PBC (``cell``) path and the flat-coordinate
    (ndim == 2) path are not yet implemented when the calculator was
    created with ``compile_cuda_graphs=True`` — they raise
    NotImplementedError.
    """
    data = self.to_input_tensors(data)
    data = self.mol_flatten(data)
    if data.get("cell") is not None:
        if self.compile_cuda_graphs:
            # TODO: PBC support for the CUDA-graphs compiled model.
            raise NotImplementedError
        if data["mol_idx"][-1] > 0:
            raise NotImplementedError("PBC with multiple molecules is not implemented yet.")
        if self._coulomb_method == "simple":
            # Simple Coulomb is not valid under PBC; fall back to DSF.
            warnings.warn("Switching to DSF Coulomb for PBC", stacklevel=1)
            self.set_lrcoulomb_method("dsf")
    if data["coord"].ndim == 2:
        if self.compile_cuda_graphs:
            # TODO: neighbor-matrix construction for the CUDA-graphs path.
            raise NotImplementedError
        data = self.make_nbmat(data)
        data = self.pad_input(data)
    return data
Expand Down Expand Up @@ -246,6 +295,7 @@ def get_derivatives(self, data: Dict[str, Tensor], forces=False, stress=False, h
training = getattr(self.model, "training", False)
_create_graph = hessian or training
x = []

if hessian:
forces = True
if forces and ("forces" not in data or (_create_graph and not data["forces"].requires_grad)):
Expand Down
46 changes: 40 additions & 6 deletions aimnet/models/aimnet2.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,23 @@ def __init__(
self.outputs = nn.ModuleDict(outputs)
else:
raise TypeError("`outputs` is not either list or dict")

# flags for enabling compile and cudagraphs
self.compile_cudagraphs = False
self.nb_mode = -1


def setup_for_compile_cudagraphs(self):
    """Switch the model into a torch.compile / CUDA-graphs friendly mode.

    CUDA graphs cannot capture data-dependent control flow or dynamically
    shaped tensors, so each submodule is told to replace those code paths
    with attribute-driven, statically shaped equivalents.
    """
    # Only neighbor-matrix mode 0 is currently supported for compiled runs.
    self.nb_mode = 0
    # Propagate the switch to the AEV module and to every output head.
    self.aev.setup_compile_cudagraphs()
    for head in self.outputs.children():
        head.setup_compile_cudagraphs()
    self.compile_cudagraphs = True

def _preprocess_spin_polarized_charge(self, data: Dict[str, Tensor]) -> Dict[str, Tensor]:
if "mult" not in data:
Expand All @@ -105,15 +122,23 @@ def _postprocess_spin_polarized_charge(self, data: Dict[str, Tensor]) -> Dict[st
return data

def _prepare_in_a(self, data: Dict[str, Tensor]) -> Tensor:
    """Assemble the input features for the atomic-feature ("a") MLP pass."""
    # Choose the dynamic or the statically shaped (CUDA-graphs safe)
    # neighbor gathering depending on the configured mode.
    if self.compile_cudagraphs:
        feat_i, feat_j = nbops.get_ij_compile_cudagraphs(data["a"], data, self.nb_mode)
    else:
        feat_i, feat_j = nbops.get_ij(data["a"], data)

    conv_out = self.conv_a(feat_j, data["gs"], data["gv"])
    if self.d2features:
        feat_i = feat_i.flatten(-2, -1)
    return torch.cat([feat_i.squeeze(-2), conv_out], dim=-1)

def _prepare_in_q(self, data: Dict[str, Tensor]) -> Tensor:
    """Assemble the input features for the charge ("q") MLP pass."""
    # Dynamic vs. statically shaped neighbor gathering, mirroring _prepare_in_a.
    if self.compile_cudagraphs:
        chg_i, chg_j = nbops.get_ij_compile_cudagraphs(data["charges"], data, self.nb_mode)
    else:
        chg_i, chg_j = nbops.get_ij(data["charges"], data)

    conv_out = self.conv_q(chg_j, data["gs"], data["gv"])
    return torch.cat([chg_i.squeeze(-2), conv_out], dim=-1)
Expand All @@ -128,11 +153,19 @@ def _update_q(self, data: Dict[str, Tensor], x: Tensor, delta_q: bool = True) ->
dim=-1,
)
# for loss
data["_delta_Q"] = data["charge"] - nbops.mol_sum(_q, data)
if not self.compile_cudagraphs:
data["_delta_Q"] = data["charge"] - nbops.mol_sum(_q, data)
else:
data["_delta_Q"] = data["charge"] - nbops.mol_sum_compile_cudagraphs(_q, data, self.nb_mode)

q = data["charges"] + _q if delta_q else _q
data["charges_pre"] = q if self.num_charge_channels == 2 else q.squeeze(-1)
f = _f.pow(2)
q = ops.nse(data["charge"], q, f, data, epsilon=1.0e-6)
if not self.compile_cudagraphs:
q = ops.nse(data["charge"], q, f, data, epsilon=1.0e-6)
else:
q = ops.nse_compile_cudagraphs(data["charge"], q, f, data, self.nb_mode, epsilon=1.0e-6)

data["charges"] = q
data["a"] = data["a"] + delta_a.view_as(data["a"])
return data
Expand Down Expand Up @@ -165,8 +198,9 @@ def forward(self, data: Dict[str, Tensor]) -> Dict[str, Tensor]:
_in = torch.cat([self._prepare_in_a(data), self._prepare_in_q(data)], dim=-1)

_out = mlp(_in)
if data["_input_padded"].item():
_out = nbops.mask_i_(_out, data, mask_value=0.0)
if not self.compile_cudagraphs:
if data["_input_padded"].item():
_out = nbops.mask_i_(_out, data, mask_value=0.0)

if ipass == 0:
data = self._update_q(data, _out, delta_q=False)
Expand Down
15 changes: 13 additions & 2 deletions aimnet/models/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,9 @@ class AIMNet2Base(nn.Module): # pylint: disable=abstract-method
def __init__(self):
super().__init__()

# flags for enabling compile and cudagraphs
self.compile_cudagraphs = False

def _prepare_dtype(self, data: Dict[str, Tensor]) -> Dict[str, Tensor]:
for k, d in zip(self._required_keys, self._required_keys_dtype):
assert k in data, f"Key {k} is required"
Expand All @@ -42,8 +45,16 @@ def _prepare_dtype(self, data: Dict[str, Tensor]) -> Dict[str, Tensor]:
def prepare_input(self, data: Dict[str, Tensor]) -> Dict[str, Tensor]:
"""Some sommon operations"""
data = self._prepare_dtype(data)
data = nbops.set_nb_mode(data)
data = nbops.calc_masks(data)

if not self.compile_cudagraphs:
data = nbops.set_nb_mode(data)
data = nbops.calc_masks(data)
else:


# check here we have one molecule
assert data["coord"].shape[0] == 1
data = nbops.calc_masks_compile_cudagraphs(data, nb_mode=0)

assert data["charge"].ndim == 1, "Charge should be 1D tensor."
if "mult" in data:
Expand Down
12 changes: 11 additions & 1 deletion aimnet/modules/aev.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,12 @@ def __init__(

self.dmat_fill = rc_s

# flag for enabling compile and cudagraphs
self.compile_cudagraphs=False

def setup_compile_cudagraphs(self):
    """Enable the statically shaped code path used under torch.compile with CUDA graphs."""
    self.compile_cudagraphs = True

def _init_basis(self, rc, eta, nshifts, shifts, rmin, mod="_s"):
self.register_parameter(
"rc" + mod,
Expand All @@ -80,7 +86,11 @@ def _init_basis(self, rc, eta, nshifts, shifts, rmin, mod="_s"):

def forward(self, data: Dict[str, Tensor]) -> Dict[str, Tensor]:
# shapes (..., m) and (..., m, 3)
d_ij, r_ij = ops.calc_distances(data)
if not self.compile_cudagraphs:
d_ij, r_ij = ops.calc_distances(data)
else:
d_ij, r_ij = ops.calc_distances_compile_cudagraphs(data, nb_mode=0)

data["d_ij"] = d_ij
# shapes (..., nshifts, m) and (..., nshifts, 3, m)
u_ij, gs, gv = self._calc_aev(r_ij, d_ij, data) # pylint: disable=unused-variable
Expand Down
34 changes: 31 additions & 3 deletions aimnet/modules/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,9 @@ def __init__(
self.key_out = key_out
self.reduce_sum = reduce_sum

def setup_compile_cudagraphs(self):
    """No-op: this module has no dynamic code paths to replace for CUDA graphs."""

def extra_repr(self) -> str:
    """Return the extra information shown in this module's repr()."""
    return "key_in: {}, key_out: {}".format(self.key_in, self.key_out)

Expand All @@ -111,11 +114,24 @@ def __init__(self, key_in: str, key_out: str):
self.key_in = key_in
self.key_out = key_out

# flags for compile and cudagraph functionality
self.compile_cuda_graphs = False
self.nb_mode = -1

def setup_compile_cudagraphs(self):
    """Switch this module to the statically shaped reduction path for CUDA-graphs capture."""
    # Only neighbor-matrix mode 0 is supported by the compiled kernels.
    self.nb_mode = 0
    self.compile_cuda_graphs = True


def extra_repr(self) -> str:
    """Human-readable summary of the configured data keys for repr()."""
    return "key_in: {key_in}, key_out: {key_out}".format(key_in=self.key_in, key_out=self.key_out)

def forward(self, data: Dict[str, Tensor]) -> Dict[str, Tensor]:
    """Sum the per-atom tensor under ``key_in`` over each molecule and store it under ``key_out``."""
    # Pick the reduction matching the execution mode: the compiled variant
    # avoids data-dependent shapes so it can be captured by CUDA graphs.
    if self.compile_cuda_graphs:
        summed = nbops.mol_sum_compile_cudagraphs(data[self.key_in], data, self.nb_mode)
    else:
        summed = nbops.mol_sum(data[self.key_in], data)
    data[self.key_out] = summed
    return data


Expand All @@ -128,13 +144,25 @@ def __init__(self, mlp: Dict | nn.Module, n_in: int, n_out: int, key_in: str, ke
mlp = MLP(n_in=n_in, n_out=n_out, **mlp)
self.mlp = mlp

# flags for compile and cudagraph functionality
self.compile_cuda_graphs = False
self.nb_mode = -1

def setup_compile_cudagraphs(self):
    """Prepare this output head for torch.compile with CUDA graphs enabled."""
    # The compiled path currently assumes neighbor-matrix mode 0.
    self.nb_mode = 0
    self.compile_cuda_graphs = True

def extra_repr(self) -> str:
    """Report the input/output dict keys in repr()."""
    parts = ("key_in: " + str(self.key_in), "key_out: " + str(self.key_out))
    return ", ".join(parts)

def forward(self, data: Dict[str, Tensor]) -> Dict[str, Tensor]:
    """Apply the MLP to the ``key_in`` features and store the squeezed result under ``key_out``."""
    out = self.mlp(data[self.key_in]).squeeze(-1)
    if self.compile_cuda_graphs:
        # TODO: masking of padded atoms is not yet implemented for the
        # CUDA-graphs path; padded entries are left untouched here.
        pass
    elif data["_input_padded"].item():
        # Zero out contributions from padding atoms.
        out = nbops.mask_i_(out, data, mask_value=0.0)
    data[self.key_out] = out
    return data

Expand Down
Loading