From 78fb758faee328078bb1ab8b4755b16d62513644 Mon Sep 17 00:00:00 2001 From: Stephen Farr Date: Tue, 11 Nov 2025 11:48:38 +0100 Subject: [PATCH] enable torch compile with cudagraphs --- aimnet/calculators/aimnet2ase.py | 4 +- aimnet/calculators/calculator.py | 52 ++++++++++++++++++++- aimnet/models/aimnet2.py | 46 ++++++++++++++++--- aimnet/models/base.py | 15 ++++++- aimnet/modules/aev.py | 12 ++++- aimnet/modules/core.py | 34 ++++++++++++-- aimnet/modules/lr.py | 77 +++++++++++++++++++++++++++----- aimnet/nbops.py | 32 +++++++++++++ aimnet/ops.py | 53 ++++++++++++++++++++-- examples/ase_md.py | 73 ++++++++++++++++++++++++++++++ 10 files changed, 368 insertions(+), 30 deletions(-) create mode 100644 examples/ase_md.py diff --git a/aimnet/calculators/aimnet2ase.py b/aimnet/calculators/aimnet2ase.py index 1caea5e..e744406 100644 --- a/aimnet/calculators/aimnet2ase.py +++ b/aimnet/calculators/aimnet2ase.py @@ -14,10 +14,10 @@ class AIMNet2ASE(Calculator): implemented_properties: ClassVar[list[str]] = ["energy", "forces", "free_energy", "charges", "stress", "dipole_moment"] - def __init__(self, base_calc: AIMNet2Calculator | str = "aimnet2", charge=0, mult=1): + def __init__(self, base_calc: AIMNet2Calculator | str = "aimnet2", charge=0, mult=1, compile_cuda_graphs: bool=False): super().__init__() if isinstance(base_calc, str): - base_calc = AIMNet2Calculator(base_calc) + base_calc = AIMNet2Calculator(base_calc, compile_cuda_graphs=compile_cuda_graphs) self.base_calc = base_calc self.reset() self.charge = charge diff --git a/aimnet/calculators/calculator.py b/aimnet/calculators/calculator.py index f94eff7..86fdd89 100644 --- a/aimnet/calculators/calculator.py +++ b/aimnet/calculators/calculator.py @@ -1,4 +1,5 @@ import warnings +import os from typing import Any, ClassVar, Dict, Literal import torch @@ -7,6 +8,19 @@ from .model_registry import get_model_path from .nbmat import TooManyNeighborsError, calc_nbmat +from aimnet.config import build_module + + + +def 
build_model(model_def): + """Build model from yaml + + function copied from test/test_model.py + """ + assert os.path.exists(model_def), f"Model definition file not found: {model_def}." + model = build_module(model_def) + assert isinstance(model, nn.Module), "The model is not an instance of AIMNet2." + return model class AIMNet2Calculator: """Genegic AIMNet2 calculator @@ -28,7 +42,7 @@ class AIMNet2Calculator: keys_out: ClassVar[list[str]] = ["energy", "charges", "forces", "hessian", "stress"] atom_feature_keys: ClassVar[list[str]] = ["coord", "numbers", "charges", "forces"] - def __init__(self, model: str | nn.Module = "aimnet2", nb_threshold: int = 320): + def __init__(self, model: str | nn.Module = "aimnet2", nb_threshold: int = 320, compile_cuda_graphs: bool=False): self.device = "cuda" if torch.cuda.is_available() else "cpu" if isinstance(model, str): p = get_model_path(model) @@ -37,7 +51,24 @@ def __init__(self, model: str | nn.Module = "aimnet2", nb_threshold: int = 320): self.model = model.to(self.device) else: raise TypeError("Invalid model type/name.") + + if compile_cuda_graphs: + self.compile_cuda_graphs = True + # to torch compile we need to un-jit + + # build the model using the yaml + aimnet2_d3_def = os.path.join(os.path.dirname(__file__), "..", "models", "aimnet2_dftd3_wb97m.yaml") + python_model = build_model(aimnet2_d3_def) + # copy weights from the torchscript model loaded above + python_model.load_state_dict(self.model.state_dict(), strict=False) + for p in python_model.parameters(): + p.requires_grad_(False) + python_model = python_model.eval() + else: + self.compile_cuda_graphs = False + + # setup things as usual from the torchscipt one self.cutoff = self.model.cutoff self.lr = hasattr(self.model, "cutoff_lr") self.cutoff_lr = getattr(self.model, "cutoff_lr", float("inf")) if self.lr else None @@ -57,6 +88,13 @@ def __init__(self, model: str | nn.Module = "aimnet2", nb_threshold: int = 320): self._coulomb_method = coul_methods.pop() else: 
self._coulomb_method = None + + # now we set up and replace + if self.compile_cuda_graphs: + python_model.setup_for_compile_cudagraphs() + self.model = python_model.to(self.device) + # now we can compile with cuda-graphs enabled + self.model = torch.compile(self.model, fullgraph=True, options={'triton.cudagraphs':True}) def __call__(self, *args, **kwargs): return self.eval(*args, **kwargs) @@ -84,22 +122,33 @@ def eval(self, data: Dict[str, Any], forces=False, stress=False, hessian=False) if hessian and "mol_idx" in data and data["mol_idx"][-1] > 0: raise NotImplementedError("Hessian calculation is not supported for multiple molecules") data = self.set_grad_tensors(data, forces=forces, stress=stress, hessian=hessian) + with torch.jit.optimized_execution(False): # type: ignore data = self.model(data) + data = self.get_derivatives(data, forces=forces, stress=stress, hessian=hessian) data = self.process_output(data) return data def prepare_input(self, data: Dict[str, Any]) -> Dict[str, Tensor]: + data = self.to_input_tensors(data) data = self.mol_flatten(data) if data.get("cell") is not None: + if self.compile_cuda_graphs: + #TODO: + raise NotImplementedError if data["mol_idx"][-1] > 0: raise NotImplementedError("PBC with multiple molecules is not implemented yet.") if self._coulomb_method == "simple": warnings.warn("Switching to DSF Coulomb for PBC", stacklevel=1) self.set_lrcoulomb_method("dsf") + if data["coord"].ndim == 2: + if self.compile_cuda_graphs: + # TODO: + raise NotImplementedError + data = self.make_nbmat(data) data = self.pad_input(data) return data @@ -246,6 +295,7 @@ def get_derivatives(self, data: Dict[str, Tensor], forces=False, stress=False, h training = getattr(self.model, "training", False) _create_graph = hessian or training x = [] + if hessian: forces = True if forces and ("forces" not in data or (_create_graph and not data["forces"].requires_grad)): diff --git a/aimnet/models/aimnet2.py b/aimnet/models/aimnet2.py index 18b7f4a..59ad696 100644 
--- a/aimnet/models/aimnet2.py +++ b/aimnet/models/aimnet2.py @@ -89,6 +89,23 @@ def __init__( self.outputs = nn.ModuleDict(outputs) else: raise TypeError("`outputs` is not either list or dict") + + # flags for enabling compile and cudagraphs + self.compile_cudagraphs = False + self.nb_mode = -1 + + + def setup_for_compile_cudagraphs(self): + # for cuda graphs and torch compile we need to make some changes: + # any data dependent control flow needs to be changed to "attribute dependent" control flow + # any ops that create dynamically shaped tensors need to be changed to make statically shaped tensors + + # only nb_mode = 0 is currently enabled + self.nb_mode = 0 + self.aev.setup_compile_cudagraphs() + for m in self.outputs.children(): + m.setup_compile_cudagraphs() + self.compile_cudagraphs = True def _preprocess_spin_polarized_charge(self, data: Dict[str, Tensor]) -> Dict[str, Tensor]: if "mult" not in data: @@ -105,7 +122,11 @@ def _postprocess_spin_polarized_charge(self, data: Dict[str, Tensor]) -> Dict[st return data def _prepare_in_a(self, data: Dict[str, Tensor]) -> Tensor: - a_i, a_j = nbops.get_ij(data["a"], data) + if not self.compile_cudagraphs: + a_i, a_j = nbops.get_ij(data["a"], data) + else: + a_i, a_j = nbops.get_ij_compile_cudagraphs(data["a"], data, self.nb_mode) + avf_a = self.conv_a(a_j, data["gs"], data["gv"]) if self.d2features: a_i = a_i.flatten(-2, -1) @@ -113,7 +134,11 @@ def _prepare_in_a(self, data: Dict[str, Tensor]) -> Tensor: return _in def _prepare_in_q(self, data: Dict[str, Tensor]) -> Tensor: - q_i, q_j = nbops.get_ij(data["charges"], data) + if not self.compile_cudagraphs: + q_i, q_j = nbops.get_ij(data["charges"], data) + else: + q_i, q_j = nbops.get_ij_compile_cudagraphs(data["charges"], data, self.nb_mode) + avf_q = self.conv_q(q_j, data["gs"], data["gv"]) _in = torch.cat([q_i.squeeze(-2), avf_q], dim=-1) return _in @@ -128,11 +153,19 @@ def _update_q(self, data: Dict[str, Tensor], x: Tensor, delta_q: bool = True) -> dim=-1, ) 
# for loss - data["_delta_Q"] = data["charge"] - nbops.mol_sum(_q, data) + if not self.compile_cudagraphs: + data["_delta_Q"] = data["charge"] - nbops.mol_sum(_q, data) + else: + data["_delta_Q"] = data["charge"] - nbops.mol_sum_compile_cudagraphs(_q, data, self.nb_mode) + q = data["charges"] + _q if delta_q else _q data["charges_pre"] = q if self.num_charge_channels == 2 else q.squeeze(-1) f = _f.pow(2) - q = ops.nse(data["charge"], q, f, data, epsilon=1.0e-6) + if not self.compile_cudagraphs: + q = ops.nse(data["charge"], q, f, data, epsilon=1.0e-6) + else: + q = ops.nse_compile_cudagraphs(data["charge"], q, f, data, self.nb_mode, epsilon=1.0e-6) + data["charges"] = q data["a"] = data["a"] + delta_a.view_as(data["a"]) return data @@ -165,8 +198,9 @@ def forward(self, data: Dict[str, Tensor]) -> Dict[str, Tensor]: _in = torch.cat([self._prepare_in_a(data), self._prepare_in_q(data)], dim=-1) _out = mlp(_in) - if data["_input_padded"].item(): - _out = nbops.mask_i_(_out, data, mask_value=0.0) + if not self.compile_cudagraphs: + if data["_input_padded"].item(): + _out = nbops.mask_i_(_out, data, mask_value=0.0) if ipass == 0: data = self._update_q(data, _out, delta_q=False) diff --git a/aimnet/models/base.py b/aimnet/models/base.py index a0a5bf1..f865083 100644 --- a/aimnet/models/base.py +++ b/aimnet/models/base.py @@ -30,6 +30,9 @@ class AIMNet2Base(nn.Module): # pylint: disable=abstract-method def __init__(self): super().__init__() + # flags for enableing compile and cudagraphs + self.compile_cudagraphs = False + def _prepare_dtype(self, data: Dict[str, Tensor]) -> Dict[str, Tensor]: for k, d in zip(self._required_keys, self._required_keys_dtype): assert k in data, f"Key {k} is required" @@ -42,8 +45,16 @@ def _prepare_dtype(self, data: Dict[str, Tensor]) -> Dict[str, Tensor]: def prepare_input(self, data: Dict[str, Tensor]) -> Dict[str, Tensor]: """Some sommon operations""" data = self._prepare_dtype(data) - data = nbops.set_nb_mode(data) - data = 
nbops.calc_masks(data) + + if not self.compile_cudagraphs: + data = nbops.set_nb_mode(data) + data = nbops.calc_masks(data) + else: + + + # check here we have one molecule + assert data["coord"].shape[0] == 1 + data = nbops.calc_masks_compile_cudagraphs(data, nb_mode=0) assert data["charge"].ndim == 1, "Charge should be 1D tensor." if "mult" in data: diff --git a/aimnet/modules/aev.py b/aimnet/modules/aev.py index a770c9b..6b11406 100644 --- a/aimnet/modules/aev.py +++ b/aimnet/modules/aev.py @@ -61,6 +61,12 @@ def __init__( self.dmat_fill = rc_s + # flag for enabling compile and cudagraphs + self.compile_cudagraphs=False + + def setup_compile_cudagraphs(self): + self.compile_cudagraphs=True + def _init_basis(self, rc, eta, nshifts, shifts, rmin, mod="_s"): self.register_parameter( "rc" + mod, @@ -80,7 +86,11 @@ def _init_basis(self, rc, eta, nshifts, shifts, rmin, mod="_s"): def forward(self, data: Dict[str, Tensor]) -> Dict[str, Tensor]: # shapes (..., m) and (..., m, 3) - d_ij, r_ij = ops.calc_distances(data) + if not self.compile_cudagraphs: + d_ij, r_ij = ops.calc_distances(data) + else: + d_ij, r_ij = ops.calc_distances_compile_cudagraphs(data, nb_mode=0) + data["d_ij"] = d_ij # shapes (..., nshifts, m) and (..., nshifts, 3, m) u_ij, gs, gv = self._calc_aev(r_ij, d_ij, data) # pylint: disable=unused-variable diff --git a/aimnet/modules/core.py b/aimnet/modules/core.py index d98b9db..d8cb171 100644 --- a/aimnet/modules/core.py +++ b/aimnet/modules/core.py @@ -94,6 +94,9 @@ def __init__( self.key_out = key_out self.reduce_sum = reduce_sum + def setup_compile_cudagraphs(self): + pass + def extra_repr(self) -> str: return f"key_in: {self.key_in}, key_out: {self.key_out}" @@ -111,11 +114,24 @@ def __init__(self, key_in: str, key_out: str): self.key_in = key_in self.key_out = key_out + # flags for compile and cudagraph functionality + self.compile_cuda_graphs = False + self.nb_mode = -1 + + def setup_compile_cudagraphs(self): + self.nb_mode = 0 + 
self.compile_cuda_graphs = True + + def extra_repr(self) -> str: return f"key_in: {self.key_in}, key_out: {self.key_out}" def forward(self, data: Dict[str, Tensor]) -> Dict[str, Tensor]: - data[self.key_out] = nbops.mol_sum(data[self.key_in], data) + if not self.compile_cuda_graphs: + data[self.key_out] = nbops.mol_sum(data[self.key_in], data) + else: + data[self.key_out] = nbops.mol_sum_compile_cudagraphs(data[self.key_in], data, self.nb_mode) + return data @@ -128,13 +144,25 @@ def __init__(self, mlp: Dict | nn.Module, n_in: int, n_out: int, key_in: str, ke mlp = MLP(n_in=n_in, n_out=n_out, **mlp) self.mlp = mlp + # flags for compile and cudagraph functionality + self.compile_cuda_graphs = False + self.nb_mode = -1 + + def setup_compile_cudagraphs(self): + self.nb_mode = 0 + self.compile_cuda_graphs = True + def extra_repr(self) -> str: return f"key_in: {self.key_in}, key_out: {self.key_out}" def forward(self, data: Dict[str, Tensor]) -> Dict[str, Tensor]: v = self.mlp(data[self.key_in]).squeeze(-1) - if data["_input_padded"].item(): - v = nbops.mask_i_(v, data, mask_value=0.0) + if not self.compile_cuda_graphs: + if data["_input_padded"].item(): + v = nbops.mask_i_(v, data, mask_value=0.0) + else: + #TODO: + pass data[self.key_out] = v return data diff --git a/aimnet/modules/lr.py b/aimnet/modules/lr.py index 9113d79..6322c4d 100644 --- a/aimnet/modules/lr.py +++ b/aimnet/modules/lr.py @@ -27,18 +27,39 @@ def __init__( self.method = method else: raise ValueError(f"Unknown method {method}") + self.compile_cudagraphs = False + + def setup_compile_cudagraphs(self): + self.compile_cudagraphs = True def coul_simple(self, data: Dict[str, Tensor]) -> Tensor: - data = ops.lazy_calc_dij_lr(data) + if not self.compile_cudagraphs: + data = ops.lazy_calc_dij_lr(data) + else: + # nb_mode 0 only + data["d_ij_lr"] = data["d_ij"] + d_ij = data["d_ij_lr"] q = data[self.key_in] - q_i, q_j = nbops.get_ij(q, data, suffix="_lr") + + if not self.compile_cudagraphs: + q_i, q_j = 
nbops.get_ij(q, data, suffix="_lr") + else: + # nb_mode 0 only + q_i, q_j = nbops.get_ij_compile_cudagraphs(q, data, 0, suffix="_lr") + + q_ij = q_i * q_j fc = 1.0 - ops.exp_cutoff(d_ij, self.rc) e_ij = fc * q_ij / d_ij e_ij = nbops.mask_ij_(e_ij, data, 0.0, suffix="_lr") e_i = e_ij.sum(-1) - e = self._factor * nbops.mol_sum(e_i, data) + if not self.compile_cudagraphs: + e = self._factor * nbops.mol_sum(e_i, data) + else: + # nb_mode 0 only + e = self._factor * nbops.mol_sum_compile_cudagraphs(e_i, data, 0) + return e def coul_simple_sr(self, data: Dict[str, Tensor]) -> Tensor: @@ -194,26 +215,46 @@ def __init__(self, s8: float, a1: float, a2: float, s6: float = 1.0, key_out="en self.register_buffer("cnmax", torch.zeros(95)) sd = constants.get_dftd3_param() self.load_state_dict(sd) + self.compile_cudagraphs = False + + def setup_compile_cudagraphs(self): + self.compile_cudagraphs = True def _calc_c6ij(self, data: Dict[str, Tensor]) -> Tensor: # CN part # short range for CN # d_ij = data["d_ij"] * constants.Bohr_inv - data = ops.lazy_calc_dij_lr(data) + if not self.compile_cudagraphs: + data = ops.lazy_calc_dij_lr(data) + else: + # nb mode 0 only + data["d_ij_lr"] = data["d_ij"] + d_ij = data["d_ij_lr"] * constants.Bohr_inv numbers = data["numbers"] - numbers_i, numbers_j = nbops.get_ij(numbers, data, suffix="_lr") - rcov_i, rcov_j = nbops.get_ij(self.rcov[numbers], data, suffix="_lr") + + if not self.compile_cudagraphs: + numbers_i, numbers_j = nbops.get_ij(numbers, data, suffix="_lr") + rcov_i, rcov_j = nbops.get_ij(self.rcov[numbers], data, suffix="_lr") + else: + numbers_i, numbers_j = nbops.get_ij_compile_cudagraphs(numbers, data, 0, suffix="_lr") + rcov_i, rcov_j = nbops.get_ij_compile_cudagraphs(self.rcov[numbers], data, 0, suffix="_lr") + rcov_ij = rcov_i + rcov_j cn_ij = 1.0 / (1.0 + torch.exp(self.k1 * (rcov_ij / d_ij - 1.0))) - cn_ij = nbops.mask_ij_(cn_ij, data, 0.0, suffix="_lr") + cn_ij = nbops.mask_ij_(cn_ij, data, 0.0, inplace=False, suffix="_lr") cn 
= cn_ij.sum(-1) cn = torch.clamp(cn, max=self.cnmax[numbers]).unsqueeze(-1).unsqueeze(-1) - cn_i, cn_j = nbops.get_ij(cn, data, suffix="_lr") + + if not self.compile_cudagraphs: + cn_i, cn_j = nbops.get_ij(cn, data, suffix="_lr") + else: + cn_i, cn_j = nbops.get_ij_compile_cudagraphs(cn, data, 0, suffix="_lr") + c6ab = self.c6ab[numbers_i, numbers_j] c6ref, cnref_i, cnref_j = torch.unbind(c6ab, dim=-1) - c6ref = nbops.mask_ij_(c6ref, data, 0.0, suffix="_lr") + c6ref = nbops.mask_ij_(c6ref, data, 0.0, inplace=False, suffix="_lr") l_ij = torch.exp(self.k3 * ((cn_i - cnref_i).pow(2) + (cn_j - cnref_j).pow(2))) w = l_ij.flatten(-2, -1).sum(-1) z = torch.einsum("...ij,...ij->...", c6ref, l_ij) @@ -226,15 +267,27 @@ def forward(self, data: Dict[str, Tensor]) -> Dict[str, Tensor]: c6ij = self._calc_c6ij(data) rr = self.r4r2[data["numbers"]] - rr_i, rr_j = nbops.get_ij(rr, data, suffix="_lr") + + if not self.compile_cudagraphs: + rr_i, rr_j = nbops.get_ij(rr, data, suffix="_lr") + else: + rr_i, rr_j = nbops.get_ij_compile_cudagraphs(rr, data, 0, suffix="_lr") + rrij = 3 * rr_i * rr_j rrij = nbops.mask_ij_(rrij, data, 1.0, suffix="_lr") r0ij = self.a1 * rrij.sqrt() + self.a2 - ops.lazy_calc_dij_lr(data) + if not self.compile_cudagraphs: + ops.lazy_calc_dij_lr(data) + d_ij = data["d_ij_lr"] * constants.Bohr_inv e_ij = c6ij * (self.s6 / (d_ij.pow(6) + r0ij.pow(6)) + self.s8 * rrij / (d_ij.pow(8) + r0ij.pow(8))) - e = -constants.half_Hartree * nbops.mol_sum(e_ij.sum(-1), data) + + if not self.compile_cudagraphs: + e = -constants.half_Hartree * nbops.mol_sum(e_ij.sum(-1), data) + else: + e = -constants.half_Hartree * nbops.mol_sum_compile_cudagraphs(e_ij.sum(-1), data, 0) + if self.key_out in data: data[self.key_out] = data[self.key_out] + e diff --git a/aimnet/nbops.py b/aimnet/nbops.py index 64d6406..f12df00 100644 --- a/aimnet/nbops.py +++ b/aimnet/nbops.py @@ -22,6 +22,23 @@ def get_nb_mode(data: Dict[str, Tensor]) -> int: """Get the neighbor model.""" return 
int(data["_nb_mode"].item()) +def calc_masks_compile_cudagraphs(data: Dict[str, Tensor], nb_mode: int) -> Dict[str, Tensor]: + """Calculate neighbor masks""" + if nb_mode != 0: + raise NotImplementedError + + data["mask_i"] = data["numbers"] == 0 + data["mask_ij"] = torch.eye( + data["numbers"].shape[1], device=data["numbers"].device, dtype=torch.bool + ).unsqueeze(0) + # data["_input_padded"] = False + data["_natom"] = torch.tensor(data["numbers"].shape[1], device=data["numbers"].device) + data["mol_sizes"] = torch.tensor(data["numbers"].shape[1], device=data["numbers"].device) + data["mask_ij_lr"] = data["mask_ij"] + + return data + + def calc_masks(data: Dict[str, Tensor]) -> Dict[str, Tensor]: """Calculate neighbor masks""" @@ -127,6 +144,13 @@ def get_ij(x: Tensor, data: Dict[str, Tensor], suffix: str = "") -> Tuple[Tensor raise ValueError(f"Invalid neighbor mode: {nb_mode}") return x_i, x_j +def get_ij_compile_cudagraphs(x: Tensor, data: Dict[str, Tensor], nb_mode: int, suffix: str = "") -> Tuple[Tensor, Tensor]: + if nb_mode == 0: + x_i = x.unsqueeze(2) + x_j = x.unsqueeze(1) + else: + raise NotADirectoryError + return x_i, x_j def mol_sum(x: Tensor, data: Dict[str, Tensor]) -> Tensor: nb_mode = get_nb_mode(data) @@ -149,3 +173,11 @@ def mol_sum(x: Tensor, data: Dict[str, Tensor]) -> Tensor: else: raise ValueError(f"Invalid neighbor mode: {nb_mode}") return res + + +def mol_sum_compile_cudagraphs(x: Tensor, data: Dict[str, Tensor], nb_mode: int) -> Tensor: + if nb_mode in (0, 2): + res = x.sum(dim=1) + else: + raise NotImplementedError + return res diff --git a/aimnet/ops.py b/aimnet/ops.py index dce29d1..eb374e1 100644 --- a/aimnet/ops.py +++ b/aimnet/ops.py @@ -32,6 +32,17 @@ def calc_distances(data: Dict[str, Tensor], suffix: str = "", pad_value: float = d_ij = torch.norm(r_ij, p=2, dim=-1) return d_ij, r_ij +def calc_distances_compile_cudagraphs(data: Dict[str, Tensor], nb_mode: int, suffix: str = "", pad_value: float = 1.0) -> Tuple[Tensor, Tensor]: + 
if nb_mode != 0: + raise NotImplementedError + # nb_mode 0 only : + coord_i, coord_j = nbops.get_ij_compile_cudagraphs(data["coord"], data, nb_mode, suffix) + r_ij = coord_j - coord_i + r_ij = nbops.mask_ij_(r_ij, data, mask_value=pad_value, inplace=False, suffix=suffix) + d_ij = torch.norm(r_ij, p=2, dim=-1) + return d_ij, r_ij + + def center_coordinates(coord: Tensor, data: Dict[str, Tensor], masses: Optional[Tensor] = None) -> Tensor: if masses is not None: @@ -46,9 +57,12 @@ def center_coordinates(coord: Tensor, data: Dict[str, Tensor], masses: Optional[ return coord -def cosine_cutoff(d_ij: Tensor, rc: float) -> Tensor: - fc = 0.5 * (torch.cos(d_ij.clamp(min=1e-6, max=rc) * (math.pi / rc)) + 1.0) - return fc +# def cosine_cutoff(d_ij: Tensor, rc: float) -> Tensor: +# fc = 0.5 * (torch.cos(d_ij.clamp(min=1e-6, max=rc) * (math.pi / rc)) + 1.0) +# return fc + +def cosine_cutoff(d_ij: Tensor, rc: Tensor) -> Tensor: + return torch.where(d_ij < rc, 0.5 * (torch.cos(d_ij * (torch.pi / rc)) + 1), 0.0) def exp_cutoff(d: Tensor, rc: Tensor) -> Tensor: @@ -94,6 +108,39 @@ def nse( return q +def nse_compile_cudagraphs( + Q: Tensor, + q_u: Tensor, + f_u: Tensor, + data: Dict[str, Tensor], + nb_mode: int, + epsilon: float = 1.0e-6, +) -> Tensor: + # Q and q_u and f_u must have last dimension size 1 or 2 + F_u = nbops.mol_sum_compile_cudagraphs(f_u, data, nb_mode) + if epsilon > 0: + F_u = F_u + epsilon + Q_u = nbops.mol_sum_compile_cudagraphs(q_u, data, nb_mode) + dQ = Q - Q_u + # for loss + data["_dQ"] = dQ + + if nb_mode in (0, 2): + F_u = F_u.unsqueeze(-2) + dQ = dQ.unsqueeze(-2) + elif nb_mode == 1: + raise NotImplementedError + data["mol_sizes"][-1] += 1 + F_u = torch.repeat_interleave(F_u, data["mol_sizes"], dim=0) + dQ = torch.repeat_interleave(dQ, data["mol_sizes"], dim=0) + data["mol_sizes"][-1] -= 1 + else: + raise ValueError(f"Invalid neighbor mode: {nb_mode}") + f = f_u / F_u + q = q_u + f * dQ + return q + + def coulomb_matrix_dsf(d_ij: Tensor, Rc: float, 
alpha: float, data: Dict[str, Tensor]) -> Tensor: _c1 = (alpha * d_ij).erfc() / d_ij _c2 = math.erfc(alpha * Rc) / Rc diff --git a/examples/ase_md.py b/examples/ase_md.py new file mode 100644 index 0000000..b395924 --- /dev/null +++ b/examples/ase_md.py @@ -0,0 +1,73 @@ +import os +import sys + +from time import perf_counter + +import ase.io +from ase import units +from ase.md.langevin import Langevin +from ase.md import MDLogger + + +from aimnet.calculators import AIMNet2ASE + + +def torch_show_device_into(): + import torch + + print(f"Torch version: {torch.__version__}") + if torch.cuda.is_available(): + print(f"CUDA available, version {torch.version.cuda}, device: {torch.cuda.get_device_name()}") # type: ignore + else: + print("CUDA not available") + + +torch_show_device_into() +# 59 conformations of taxol +xyzfile = os.path.join(os.path.dirname(__file__), "taxol.xyz") + +# read the first one +atoms = ase.io.read(xyzfile, index=0) + +# create the calculator with default model and enable/disable torch compile with cudagraphs +COMPILE_CUDAGRAPHS=False +if COMPILE_CUDAGRAPHS: + print('running with torch_compile+cudagraphs') +else: + print('running without torch_compile+cudagraphs') + +calc = AIMNet2ASE(compile_cuda_graphs=COMPILE_CUDAGRAPHS) + +# attach the calculator to the atoms object +atoms.calc = calc # type: ignore + +# do a single point calculation to trigger compile and do a warmup step +forces = atoms.get_forces() +energy = atoms.get_potential_energy() +print('energy:', energy) + + +# setup MD +temperature_K: float = 300 +timestep: float = 1.0 * units.fs +friction: float = 0.01 / units.fs +traj_interval: int = 1000 +log_interval: int = 1000 +nsteps: int = 10000 + +dyn = Langevin(atoms, timestep, temperature_K=temperature_K, friction=friction) +dyn.attach( + lambda: ase.io.write("traj.xyz", atoms, append=True), interval=traj_interval +) +dyn.attach(MDLogger(dyn, atoms, sys.stdout), interval=log_interval) + + +# Run the dynamics +t1 = perf_counter() 
+dyn.run(steps=nsteps) +t2 = perf_counter() + +print(f"Completed MD in {t2 - t1:.1f} s ({(t2 - t1)*1000 / nsteps:.3f} ms/step)") + + +