diff --git a/.devcontainer/devcontainer.json b/.devcontainer/devcontainer.json index a61f3b8..2fe2a0c 100644 --- a/.devcontainer/devcontainer.json +++ b/.devcontainer/devcontainer.json @@ -39,4 +39,3 @@ } } } - diff --git a/.github/dependabot.yml b/.github/dependabot.yml index 89bbf15..b3a568d 100644 --- a/.github/dependabot.yml +++ b/.github/dependabot.yml @@ -22,4 +22,3 @@ updates: labels: - "dependencies" - "ci" - diff --git a/.github/workflows/gpu-tests.yml b/.github/workflows/gpu-tests.yml index ae594c9..5a4af04 100644 --- a/.github/workflows/gpu-tests.yml +++ b/.github/workflows/gpu-tests.yml @@ -36,4 +36,3 @@ jobs: - name: Run GPU tests run: uv run pytest tests -m gpu -v - diff --git a/aimnet/calculators/aimnet2ase.py b/aimnet/calculators/aimnet2ase.py index 48cce3f..0a37fbb 100644 --- a/aimnet/calculators/aimnet2ase.py +++ b/aimnet/calculators/aimnet2ase.py @@ -21,10 +21,10 @@ class AIMNet2ASE(Calculator): "dipole_moment", ] - def __init__(self, base_calc: AIMNet2Calculator | str = "aimnet2", charge=0, mult=1): + def __init__(self, base_calc: AIMNet2Calculator | str = "aimnet2", charge=0, mult=1, compile_mode: bool = False): super().__init__() if isinstance(base_calc, str): - base_calc = AIMNet2Calculator(base_calc) + base_calc = AIMNet2Calculator(base_calc, compile_mode=compile_mode) self.base_calc = base_calc self.reset() self.charge = charge @@ -86,13 +86,15 @@ def calculate(self, atoms=None, properties=None, system_changes=all_changes): "mult": self._t_mult, } + # In compile mode, calculator handles batching internally _unsqueezed = False - if cell is not None: - _in["cell"] = cell - else: - for k, v in _in.items(): - _in[k] = v.unsqueeze(0) - _unsqueezed = True + if not self.base_calc._compile_mode: + if cell is not None: + _in["cell"] = cell + else: + for k, v in _in.items(): + _in[k] = v.unsqueeze(0) + _unsqueezed = True results = self.base_calc(_in, forces="forces" in properties, stress="stress" in properties) diff --git a/aimnet/calculators/calculator.py b/aimnet/calculators/calculator.py index 3bd137e..9d4c740 100644 --- a/aimnet/calculators/calculator.py +++ b/aimnet/calculators/calculator.py @@ -4,7 +4,9 @@ import torch from torch import Tensor, nn -from .model_registry import get_model_path +from aimnet.config import build_module + +from .model_registry import get_model_definition_path, get_model_path from .nbmat import TooManyNeighborsError, calc_nbmat @@ -28,11 +30,27 @@ class AIMNet2Calculator: keys_out: ClassVar[list[str]] = ["energy", "charges", "forces", "hessian", "stress"] atom_feature_keys: ClassVar[list[str]] = ["coord", "numbers", "charges", "forces"] - def __init__(self, model: str | nn.Module = "aimnet2", nb_threshold: int = 320): + def __init__(self, model: str | nn.Module = "aimnet2", nb_threshold: int = 320, compile_mode: bool = False): self.device = "cuda" if torch.cuda.is_available() else "cpu" + self._compile_mode = compile_mode + self._model_name: str | None = None + + if compile_mode: + if not torch.cuda.is_available(): + raise ValueError("compile_mode requires CUDA") + if not isinstance(model, str): + raise ValueError("compile_mode requires model name (str), not nn.Module") + if isinstance(model, str): + self._model_name = model p = get_model_path(model) - self.model = torch.jit.load(p, map_location=self.device) + jit_model = torch.jit.load(p, map_location=self.device) + + if compile_mode: + # Build native PyTorch model for torch.compile + self.model = self._build_compiled_model(model, jit_model) + else: + self.model = jit_model elif isinstance(model, nn.Module): self.model = model.to(self.device) else: @@ -58,6 +76,44 @@ def __init__(self, model: str | nn.Module = "aimnet2", nb_threshold: int = 320): else: self._coulomb_method = None + def _build_compiled_model(self, model_name: str, jit_model: nn.Module) -> nn.Module: + """Build a native PyTorch model for torch.compile. + + This loads the YAML definition, builds a native model, copies weights + from the JIT model, enables compile mode, and applies torch.compile. + + Args: + model_name: Name of the model + jit_model: The loaded JIT model to copy weights from + + Returns: + Compiled native PyTorch model + """ + # Get YAML definition path + yaml_path = get_model_definition_path(model_name) + + # Build native model from YAML + native_model: nn.Module = build_module(yaml_path) # type: ignore[assignment] + native_model = native_model.to(self.device) + + # Copy weights from JIT model to native model + jit_state = jit_model.state_dict() + native_model.load_state_dict(jit_state) + + # Enable compile mode (fixed nb_mode=0) + if hasattr(native_model, "enable_compile_mode"): + native_model.enable_compile_mode(nb_mode=0) + + # Apply torch.compile with CUDA graphs + native_model.eval() + compiled_model = torch.compile( + native_model, + fullgraph=True, + options={"triton.cudagraphs": True}, + ) + + return compiled_model + def __call__(self, *args, **kwargs): return self.eval(*args, **kwargs) @@ -84,14 +140,32 @@ def eval(self, data: dict[str, Any], forces=False, stress=False, hessian=False) if hessian and "mol_idx" in data and data["mol_idx"][-1] > 0: raise NotImplementedError("Hessian calculation is not supported for multiple molecules") data = self.set_grad_tensors(data, forces=forces, stress=stress, hessian=hessian) - with torch.jit.optimized_execution(False): # type: ignore + if self._compile_mode: + # For compiled model, run directly data = self.model(data) + else: + with torch.jit.optimized_execution(False): # type: ignore + data = self.model(data) data = self.get_derivatives(data, forces=forces, stress=stress, hessian=hessian) data = self.process_output(data) return data def prepare_input(self, data: dict[str, Any]) -> dict[str, Tensor]: data = self.to_input_tensors(data) + + if self._compile_mode: + # Compile mode requires batch dim and nb_mode=0 (dense) + if data.get("cell") is not None: + raise NotImplementedError("PBC is not supported in compile mode") + # Ensure batch dimension for compile mode + if data["coord"].ndim == 2: + for k in ("coord", "numbers"): + data[k] = data[k].unsqueeze(0) + data["charge"] = data["charge"].reshape(1) + if "mult" in data: + data["mult"] = data["mult"].reshape(1) + return data + data = self.mol_flatten(data) if data.get("cell") is not None: if data["mol_idx"][-1] > 0: @@ -105,6 +179,15 @@ def prepare_input(self, data: dict[str, Any]) -> dict[str, Tensor]: return data def process_output(self, data: dict[str, Tensor]) -> dict[str, Tensor]: + if self._compile_mode: + # In compile mode, remove batch dim if we added it + if data["coord"].ndim == 3 and data["coord"].shape[0] == 1: + for k in self.atom_feature_keys: + if k in data: + data[k] = data[k].squeeze(0) + data = self.keep_only(data) + return data + if data["coord"].ndim == 2: data = self.unpad_output(data) data = self.mol_unflatten(data) diff --git a/aimnet/calculators/model_registry.py b/aimnet/calculators/model_registry.py index 2408297..4db2093 100644 --- a/aimnet/calculators/model_registry.py +++ b/aimnet/calculators/model_registry.py @@ -49,6 +49,42 @@ def get_model_path(s: str) -> str: return s +def get_model_definition_path(model_name: str) -> str: + """Get the YAML definition file path for a model name. + + This maps model names to their architecture definition YAML files, + which are needed for torch.compile (requires un-jitted model). + + Args: + model_name: Model name or alias from the registry + + Returns: + Path to the YAML model definition file + """ + model_registry = load_model_registry() + + # Resolve alias first + if model_name in model_registry.get("aliases", {}): + model_name = model_registry["aliases"][model_name] + + # Determine which YAML definition to use based on model name + # Models with D3 dispersion use aimnet2_dftd3_wb97m.yaml + # Models without D3 (like NSE) use aimnet2.yaml + models_dir = os.path.join(os.path.dirname(__file__), "..", "models") + + if "nse" in model_name.lower(): + # NSE models don't have D3 + yaml_file = "aimnet2.yaml" + elif "d3" in model_name.lower() or "pd" in model_name.lower(): + # D3 and Pd models include DFTD3 + yaml_file = "aimnet2_dftd3_wb97m.yaml" + else: + # Default to D3 version for standard aimnet2 models + yaml_file = "aimnet2_dftd3_wb97m.yaml" + + return os.path.join(models_dir, yaml_file) + + @click.command(short_help="Clear assets directory.") def clear_assets(): from glob import glob diff --git a/aimnet/models/aimnet2.py b/aimnet/models/aimnet2.py index 88c708e..9ba86e2 100644 --- a/aimnet/models/aimnet2.py +++ b/aimnet/models/aimnet2.py @@ -104,7 +104,8 @@ def _postprocess_spin_polarized_charge(self, data: dict[str, Tensor]) -> dict[st return data def _prepare_in_a(self, data: dict[str, Tensor]) -> Tensor: - a_i, a_j = nbops.get_ij(data["a"], data) + compile_nb = self._compile_nb_mode if self._compile_mode else -1 + a_i, a_j = nbops.get_ij(data["a"], data, compile_nb_mode=compile_nb) avf_a = self.conv_a(a_j, data["gs"], data["gv"]) if self.d2features: a_i = a_i.flatten(-2, -1) @@ -112,12 +113,14 @@ def _prepare_in_a(self, data: dict[str, Tensor]) -> Tensor: return _in def _prepare_in_q(self, data: dict[str, Tensor]) -> Tensor: - q_i, q_j = nbops.get_ij(data["charges"], data) + compile_nb = self._compile_nb_mode if self._compile_mode else -1 + q_i, q_j = nbops.get_ij(data["charges"], data, compile_nb_mode=compile_nb) avf_q = self.conv_q(q_j, data["gs"], data["gv"]) _in = torch.cat([q_i.squeeze(-2), avf_q], dim=-1) return _in def _update_q(self, data: dict[str, Tensor], x: Tensor, delta_q: bool = True) -> dict[str, Tensor]: + compile_nb = self._compile_nb_mode if self._compile_mode else -1 _q, _f, delta_a = x.split( [ self.num_charge_channels, @@ -127,11 +130,11 @@ def _update_q(self, data: dict[str, Tensor], x: Tensor, delta_q: bool = True) -> dim=-1, ) # for loss - data["_delta_Q"] = data["charge"] - nbops.mol_sum(_q, data) + data["_delta_Q"] = data["charge"] - nbops.mol_sum(_q, data, compile_nb_mode=compile_nb) q = data["charges"] + _q if delta_q else _q data["charges_pre"] = q if self.num_charge_channels == 2 else q.squeeze(-1) f = _f.pow(2) - q = ops.nse(data["charge"], q, f, data, epsilon=1.0e-6) + q = ops.nse(data["charge"], q, f, data, epsilon=1.0e-6, compile_nb_mode=compile_nb) data["charges"] = q data["a"] = data["a"] + delta_a.view_as(data["a"]) return data @@ -164,7 +167,8 @@ def forward(self, data: dict[str, Tensor]) -> dict[str, Tensor]: _in = torch.cat([self._prepare_in_a(data), self._prepare_in_q(data)], dim=-1) _out = mlp(_in) - if data["_input_padded"].item(): + # In compile mode, skip the .item() check - padding is handled at setup + if not self._compile_mode and data["_input_padded"].item(): _out = nbops.mask_i_(_out, data, mask_value=0.0) if ipass == 0: diff --git a/aimnet/models/base.py b/aimnet/models/base.py index d2ece07..0250ff1 100644 --- a/aimnet/models/base.py +++ b/aimnet/models/base.py @@ -29,6 +29,33 @@ class AIMNet2Base(nn.Module): def __init__(self): super().__init__() + # Compile mode attributes + self._compile_mode: bool = False + self._compile_nb_mode: int = -1 # -1 = dynamic, 0/1/2 = fixed + + def enable_compile_mode(self, nb_mode: int = 0) -> None: + """Enable compile mode with fixed nb_mode for CUDA graphs compatibility. + + This sets up the model to use compile-time constant control flow, + avoiding .item() calls that break CUDA graph capture. + + Args: + nb_mode: Fixed neighbor mode (0=dense, 1=sparse, 2=batched). + Currently only 0 is supported. + """ + if nb_mode not in (0, 1, 2): + raise ValueError(f"nb_mode must be 0, 1, or 2, got {nb_mode}") + if nb_mode != 0: + raise NotImplementedError(f"Compile mode only supports nb_mode=0 currently, got {nb_mode}") + + self._compile_mode = True + self._compile_nb_mode = nb_mode + + # Propagate to all submodules + for module in self.modules(): + if hasattr(module, "_compile_mode"): + module._compile_mode = True + module._compile_nb_mode = nb_mode def _prepare_dtype(self, data: dict[str, Tensor]) -> dict[str, Tensor]: for k, d in zip(self._required_keys, self._required_keys_dtype, strict=False): @@ -42,8 +69,15 @@ def _prepare_dtype(self, data: dict[str, Tensor]) -> dict[str, Tensor]: def prepare_input(self, data: dict[str, Tensor]) -> dict[str, Tensor]: """Some sommon operations""" data = self._prepare_dtype(data) - data = nbops.set_nb_mode(data) - data = nbops.calc_masks(data) + + if self._compile_mode: + # In compile mode, use fixed nb_mode - no data-dependent branching + data["_nb_mode"] = torch.tensor(self._compile_nb_mode) + data = nbops.calc_masks_fixed_nb_mode(data, self._compile_nb_mode) + else: + # Dynamic mode - detect nb_mode from data + data = nbops.set_nb_mode(data) + data = nbops.calc_masks(data) assert data["charge"].ndim == 1, "Charge should be 1D tensor." if "mult" in data: diff --git a/aimnet/modules/aev.py b/aimnet/modules/aev.py index 8c5d5fa..afe6a52 100644 --- a/aimnet/modules/aev.py +++ b/aimnet/modules/aev.py @@ -44,6 +44,9 @@ def __init__( shifts_v: list[float] | None = None, ): super().__init__() + # Compile mode attributes + self._compile_mode: bool = False + self._compile_nb_mode: int = -1 self._init_basis(rc_s, eta_s, nshifts_s, shifts_s, rmin, mod="_s") if rc_v is not None: @@ -79,7 +82,8 @@ def _init_basis(self, rc, eta, nshifts, shifts, rmin, mod="_s"): def forward(self, data: dict[str, Tensor]) -> dict[str, Tensor]: # shapes (..., m) and (..., m, 3) - d_ij, r_ij = ops.calc_distances(data) + compile_nb = self._compile_nb_mode if self._compile_mode else -1 + d_ij, r_ij = ops.calc_distances(data, compile_nb_mode=compile_nb) data["d_ij"] = d_ij # shapes (..., nshifts, m) and (..., nshifts, 3, m) u_ij, gs, gv = self._calc_aev(r_ij, d_ij, data) @@ -88,14 +92,21 @@ def forward(self, data: dict[str, Tensor]) -> dict[str, Tensor]: return data def _calc_aev(self, r_ij: Tensor, d_ij: Tensor, data: dict[str, Tensor]) -> tuple[Tensor, Tensor, Tensor]: - fc_ij = ops.cosine_cutoff(d_ij, self.rc_s) # (..., m) + # In compile mode, use tensor version of cosine_cutoff for CUDA graph compatibility + if self._compile_mode: + fc_ij = ops.cosine_cutoff_tensor(d_ij, self.rc_s) + else: + fc_ij = ops.cosine_cutoff(d_ij, self.rc_s.item()) fc_ij = nbops.mask_ij_(fc_ij, data, 0.0) gs = ops.exp_expand(d_ij, self.shifts_s, self.eta_s) * fc_ij.unsqueeze( -1 ) # (..., m, nshifts) * (..., m, 1) -> (..., m, shitfs) u_ij = r_ij / d_ij.unsqueeze(-1) # (..., m, 3) / (..., m, 1) -> (..., m, 3) if self._dual_basis: - fc_ij = ops.cosine_cutoff(d_ij, self.rc_v) + if self._compile_mode: + fc_ij = ops.cosine_cutoff_tensor(d_ij, self.rc_v) + else: + fc_ij = ops.cosine_cutoff(d_ij, self.rc_v.item()) gsv = ops.exp_expand(d_ij, self.shifts_v, self.eta_v) * fc_ij.unsqueeze(-1) gv = gsv.unsqueeze(-2) * u_ij.unsqueeze(-1) else: diff --git a/aimnet/modules/core.py b/aimnet/modules/core.py index 1b450ec..d77ea9d 100644 --- a/aimnet/modules/core.py +++ b/aimnet/modules/core.py @@ -88,6 +88,10 @@ def __init__( reduce_sum=False, ): super().__init__() + # Compile mode attributes + self._compile_mode: bool = False + self._compile_nb_mode: int = -1 + shifts = nn.Embedding(num_types, 1, padding_idx=0, dtype=dtype) shifts.weight.requires_grad_(requires_grad) self.shifts = shifts @@ -101,7 +105,8 @@ def extra_repr(self) -> str: def forward(self, data: dict[str, Tensor]) -> dict[str, Tensor]: shifts = self.shifts(data["numbers"]).squeeze(-1) if self.reduce_sum: - shifts = nbops.mol_sum(shifts, data) + compile_nb = self._compile_nb_mode if self._compile_mode else -1 + shifts = nbops.mol_sum(shifts, data, compile_nb_mode=compile_nb) data[self.key_out] = data[self.key_in] + shifts return data @@ -109,6 +114,10 @@ def forward(self, data: dict[str, Tensor]) -> dict[str, Tensor]: class AtomicSum(nn.Module): def __init__(self, key_in: str, key_out: str): super().__init__() + # Compile mode attributes + self._compile_mode: bool = False + self._compile_nb_mode: int = -1 + self.key_in = key_in self.key_out = key_out @@ -116,13 +125,18 @@ def extra_repr(self) -> str: return f"key_in: {self.key_in}, key_out: {self.key_out}" def forward(self, data: dict[str, Tensor]) -> dict[str, Tensor]: - data[self.key_out] = nbops.mol_sum(data[self.key_in], data) + compile_nb = self._compile_nb_mode if self._compile_mode else -1 + data[self.key_out] = nbops.mol_sum(data[self.key_in], data, compile_nb_mode=compile_nb) return data class Output(nn.Module): def __init__(self, mlp: dict | nn.Module, n_in: int, n_out: int, key_in: str, key_out: str): super().__init__() + # Compile mode attributes + self._compile_mode: bool = False + self._compile_nb_mode: int = -1 + self.key_in = key_in self.key_out = key_out if not isinstance(mlp, nn.Module): @@ -134,7 +148,8 @@ def extra_repr(self) -> str: def forward(self, data: dict[str, Tensor]) -> dict[str, Tensor]: v = self.mlp(data[self.key_in]).squeeze(-1) - if data["_input_padded"].item(): + # In compile mode, skip the .item() check - padding is handled at setup + if not self._compile_mode and data["_input_padded"].item(): v = nbops.mask_i_(v, data, mask_value=0.0) data[self.key_out] = v return data diff --git a/aimnet/modules/lr.py b/aimnet/modules/lr.py index 3403000..386c0f7 100644 --- a/aimnet/modules/lr.py +++ b/aimnet/modules/lr.py @@ -15,6 +15,10 @@ def __init__( dsf_rc: float = 15.0, ): super().__init__() + # Compile mode attributes + self._compile_mode: bool = False + self._compile_nb_mode: int = -1 + self.key_in = key_in self.key_out = key_out self._factor = constants.half_Hartree * constants.Bohr @@ -27,38 +31,41 @@ def __init__( raise ValueError(f"Unknown method {method}") def coul_simple(self, data: dict[str, Tensor]) -> Tensor: - data = ops.lazy_calc_dij_lr(data) + compile_nb = self._compile_nb_mode if self._compile_mode else -1 + data = ops.lazy_calc_dij_lr(data, compile_nb_mode=compile_nb) d_ij = data["d_ij_lr"] q = data[self.key_in] - q_i, q_j = nbops.get_ij(q, data, suffix="_lr") + q_i, q_j = nbops.get_ij(q, data, suffix="_lr", compile_nb_mode=compile_nb) q_ij = q_i * q_j fc = 1.0 - ops.exp_cutoff(d_ij, self.rc) e_ij = fc * q_ij / d_ij e_ij = nbops.mask_ij_(e_ij, data, 0.0, suffix="_lr") e_i = e_ij.sum(-1) - e = self._factor * nbops.mol_sum(e_i, data) + e = self._factor * nbops.mol_sum(e_i, data, compile_nb_mode=compile_nb) return e def coul_simple_sr(self, data: dict[str, Tensor]) -> Tensor: + compile_nb = self._compile_nb_mode if self._compile_mode else -1 d_ij = data["d_ij"] q = data[self.key_in] - q_i, q_j = nbops.get_ij(q, data) + q_i, q_j = nbops.get_ij(q, data, compile_nb_mode=compile_nb) q_ij = q_i * q_j fc = ops.exp_cutoff(d_ij, self.rc) e_ij = fc * q_ij / d_ij e_ij = nbops.mask_ij_(e_ij, data, 0.0) e_i = e_ij.sum(-1) - e = self._factor * nbops.mol_sum(e_i, data) + e = self._factor * nbops.mol_sum(e_i, data, compile_nb_mode=compile_nb) return e def coul_dsf(self, data: dict[str, Tensor]) -> Tensor: - data = ops.lazy_calc_dij_lr(data) + compile_nb = self._compile_nb_mode if self._compile_mode else -1 + data = ops.lazy_calc_dij_lr(data, compile_nb_mode=compile_nb) d_ij = data["d_ij_lr"] q = data[self.key_in] - q_i, q_j = nbops.get_ij(q, data, suffix="_lr") + q_i, q_j = nbops.get_ij(q, data, suffix="_lr", compile_nb_mode=compile_nb) J = ops.coulomb_matrix_dsf(d_ij, self.dsf_rc, self.dsf_alpha, data) e = (q_i * q_j * J).sum(-1) - e = self._factor * nbops.mol_sum(e, data) + e = self._factor * nbops.mol_sum(e, data, compile_nb_mode=compile_nb) e = e - self.coul_simple_sr(data) return e @@ -128,6 +135,10 @@ class D3TS(nn.Module): def __init__(self, a1: float, a2: float, s8: float, s6: float = 1.0, key_in="disp_param", key_out="energy"): super().__init__() + # Compile mode attributes + self._compile_mode: bool = False + self._compile_nb_mode: int = -1 + self.register_buffer("r4r2", constants.get_r4r2()) self.a1 = a1 self.a2 = a2 @@ -137,8 +148,9 @@ def __init__(self, a1: float, a2: float, s8: float, s6: float = 1.0, key_in="dis self.key_out = key_out def forward(self, data: dict[str, Tensor]) -> dict[str, Tensor]: + compile_nb = self._compile_nb_mode if self._compile_mode else -1 disp_param = data[self.key_in] - disp_param_i, disp_param_j = nbops.get_ij(disp_param, data, suffix="_lr") + disp_param_i, disp_param_j = nbops.get_ij(disp_param, data, suffix="_lr", compile_nb_mode=compile_nb) c6_i, alpha_i = disp_param_i.unbind(dim=-1) c6_j, alpha_j = disp_param_j.unbind(dim=-1) @@ -146,15 +158,15 @@ def forward(self, data: dict[str, Tensor]) -> dict[str, Tensor]: c6ij = 2 * c6_i * c6_j / (c6_i * alpha_j / alpha_i + c6_j * alpha_i / alpha_j).clamp(min=1e-4) rr = self.r4r2[data["numbers"]] - rr_i, rr_j = nbops.get_ij(rr, data, suffix="_lr") + rr_i, rr_j = nbops.get_ij(rr, data, suffix="_lr", compile_nb_mode=compile_nb) rrij = 3 * rr_i * rr_j rrij = nbops.mask_ij_(rrij, data, 1.0, suffix="_lr") r0ij = self.a1 * rrij.sqrt() + self.a2 - ops.lazy_calc_dij_lr(data) + ops.lazy_calc_dij_lr(data, compile_nb_mode=compile_nb) d_ij = data["d_ij_lr"] * constants.Bohr_inv e_ij = c6ij * (self.s6 / (d_ij.pow(6) + r0ij.pow(6)) + self.s8 * rrij / (d_ij.pow(8) + r0ij.pow(8))) - e = -constants.half_Hartree * nbops.mol_sum(e_ij.sum(-1), data) + e = -constants.half_Hartree * nbops.mol_sum(e_ij.sum(-1), data, compile_nb_mode=compile_nb) if self.key_out in data: data[self.key_out] = data[self.key_out] + e @@ -171,6 +183,10 @@ class DFTD3(nn.Module): def __init__(self, s8: float, a1: float, a2: float, s6: float = 1.0, key_out="energy"): super().__init__() + # Compile mode attributes + self._compile_mode: bool = False + self._compile_nb_mode: int = -1 + self.key_out = key_out # BJ damping parameters self.s6 = s6 @@ -191,21 +207,22 @@ def __init__(self, s8: float, a1: float, a2: float, s6: float = 1.0, key_out="en self.load_state_dict(sd) def _calc_c6ij(self, data: dict[str, Tensor]) -> Tensor: + compile_nb = self._compile_nb_mode if self._compile_mode else -1 # CN part # short range for CN # d_ij = data["d_ij"] * constants.Bohr_inv - data = ops.lazy_calc_dij_lr(data) + data = ops.lazy_calc_dij_lr(data, compile_nb_mode=compile_nb) d_ij = data["d_ij_lr"] * constants.Bohr_inv numbers = data["numbers"] - numbers_i, numbers_j = nbops.get_ij(numbers, data, suffix="_lr") - rcov_i, rcov_j = nbops.get_ij(self.rcov[numbers], data, suffix="_lr") + numbers_i, numbers_j = nbops.get_ij(numbers, data, suffix="_lr", compile_nb_mode=compile_nb) + rcov_i, rcov_j = nbops.get_ij(self.rcov[numbers], data, suffix="_lr", compile_nb_mode=compile_nb) rcov_ij = rcov_i + rcov_j cn_ij = 1.0 / (1.0 + torch.exp(self.k1 * (rcov_ij / d_ij - 1.0))) cn_ij = nbops.mask_ij_(cn_ij, data, 0.0, suffix="_lr") cn = cn_ij.sum(-1) cn = torch.clamp(cn, max=self.cnmax[numbers]).unsqueeze(-1).unsqueeze(-1) - cn_i, cn_j = nbops.get_ij(cn, data, suffix="_lr") + cn_i, cn_j = nbops.get_ij(cn, data, suffix="_lr", compile_nb_mode=compile_nb) c6ab = self.c6ab[numbers_i, numbers_j] c6ref, cnref_i, cnref_j = torch.unbind(c6ab, dim=-1) c6ref = nbops.mask_ij_(c6ref, data, 0.0, suffix="_lr") @@ -218,18 +235,19 @@ def _calc_c6ij(self, data: dict[str, Tensor]) -> Tensor: return c6_ij def forward(self, data: dict[str, Tensor]) -> dict[str, Tensor]: + compile_nb = self._compile_nb_mode if self._compile_mode else -1 c6ij = self._calc_c6ij(data) rr = self.r4r2[data["numbers"]] - rr_i, rr_j = nbops.get_ij(rr, data, suffix="_lr") + rr_i, rr_j = nbops.get_ij(rr, data, suffix="_lr", compile_nb_mode=compile_nb) rrij = 3 * rr_i * rr_j rrij = nbops.mask_ij_(rrij, data, 1.0, suffix="_lr") r0ij = self.a1 * rrij.sqrt() + self.a2 - ops.lazy_calc_dij_lr(data) + ops.lazy_calc_dij_lr(data, compile_nb_mode=compile_nb) d_ij = data["d_ij_lr"] * constants.Bohr_inv e_ij = c6ij * (self.s6 / (d_ij.pow(6) + r0ij.pow(6)) + self.s8 * rrij / (d_ij.pow(8) + r0ij.pow(8))) - e = -constants.half_Hartree * nbops.mol_sum(e_ij.sum(-1), data) + e = -constants.half_Hartree * nbops.mol_sum(e_ij.sum(-1), data, compile_nb_mode=compile_nb) if self.key_out in data: data[self.key_out] = data[self.key_out] + e diff --git a/aimnet/nbops.py b/aimnet/nbops.py index 444c54a..9625a50 100644 --- a/aimnet/nbops.py +++ b/aimnet/nbops.py @@ -21,6 +21,39 @@ def get_nb_mode(data: dict[str, Tensor]) -> int: return int(data["_nb_mode"].item()) +def calc_masks_fixed_nb_mode(data: dict[str, Tensor], nb_mode: int) -> dict[str, Tensor]: + """Calculate masks with a fixed (compile-time known) nb_mode. + + This avoids data-dependent control flow for CUDA graphs compatibility. + Used when torch.compile with CUDA graphs is enabled. + + Args: + data: Data dictionary + nb_mode: Fixed neighbor mode (0, 1, or 2) + + Returns: + Updated data dictionary with masks + """ + if nb_mode == 0: + data["mask_i"] = data["numbers"] == 0 + data["mask_ij"] = torch.eye( + data["numbers"].shape[1], device=data["numbers"].device, dtype=torch.bool + ).unsqueeze(0) + # In compile mode with nb_mode=0, we assume single non-padded molecule + data["_input_padded"] = torch.tensor(False) + data["_natom"] = torch.tensor(data["numbers"].shape[1], device=data["numbers"].device) + data["mol_sizes"] = torch.tensor(data["numbers"].shape[1], device=data["numbers"].device) + data["mask_ij_lr"] = data["mask_ij"] + elif nb_mode == 1: + raise NotImplementedError("nb_mode=1 not yet supported in compile mode") + elif nb_mode == 2: + raise NotImplementedError("nb_mode=2 not yet supported in compile mode") + else: + raise ValueError(f"Invalid neighbor mode: {nb_mode}") + + return data + + def calc_masks(data: dict[str, Tensor]) -> dict[str, Tensor]: """Calculate neighbor masks""" nb_mode = get_nb_mode(data) @@ -108,8 +141,19 @@ def mask_i_(x: Tensor, data: dict[str, Tensor], mask_value: float = 0.0, inplace return x -def get_ij(x: Tensor, data: dict[str, Tensor], suffix: str = "") -> tuple[Tensor, Tensor]: - nb_mode = get_nb_mode(data) +def get_ij(x: Tensor, data: dict[str, Tensor], suffix: str = "", compile_nb_mode: int = -1) -> tuple[Tensor, Tensor]: + """Get i and j tensor views for pairwise operations. + + Args: + x: Input tensor + data: Data dictionary + suffix: Suffix for nbmat key (e.g., "_lr") + compile_nb_mode: If >= 0, use fixed nb_mode (compile mode) to avoid .item() call + + Returns: + Tuple of (x_i, x_j) tensors + """ + nb_mode = compile_nb_mode if compile_nb_mode >= 0 else get_nb_mode(data) if nb_mode == 0: x_i = x.unsqueeze(2) x_j = x.unsqueeze(1) @@ -126,8 +170,18 @@ def get_ij(x: Tensor, data: dict[str, Tensor], suffix: str = "") -> tuple[Tensor return x_i, x_j -def mol_sum(x: Tensor, data: dict[str, Tensor]) -> Tensor: - nb_mode = get_nb_mode(data) +def mol_sum(x: Tensor, data: dict[str, Tensor], compile_nb_mode: int = -1) -> Tensor: + """Sum over atoms in each molecule. + + Args: + x: Input tensor + data: Data dictionary + compile_nb_mode: If >= 0, use fixed nb_mode (compile mode) to avoid .item() call + + Returns: + Summed tensor + """ + nb_mode = compile_nb_mode if compile_nb_mode >= 0 else get_nb_mode(data) if nb_mode in (0, 2): res = x.sum(dim=1) elif nb_mode == 1: diff --git a/aimnet/ops.py b/aimnet/ops.py index 53c0714..55d642d 100644 --- a/aimnet/ops.py +++ b/aimnet/ops.py @@ -6,21 +6,23 @@ from aimnet import nbops -def lazy_calc_dij_lr(data: dict[str, Tensor]) -> dict[str, Tensor]: +def lazy_calc_dij_lr(data: dict[str, Tensor], compile_nb_mode: int = -1) -> dict[str, Tensor]: if "d_ij_lr" not in data: - nb_mode = nbops.get_nb_mode(data) + nb_mode = compile_nb_mode if compile_nb_mode >= 0 else nbops.get_nb_mode(data) if nb_mode == 0: data["d_ij_lr"] = data["d_ij"] else: - data["d_ij_lr"] = calc_distances(data, suffix="_lr")[0] + data["d_ij_lr"] = calc_distances(data, suffix="_lr", compile_nb_mode=compile_nb_mode)[0] return data -def calc_distances(data: dict[str, Tensor], suffix: str = "", pad_value: float = 1.0) -> tuple[Tensor, Tensor]: - coord_i, coord_j = nbops.get_ij(data["coord"], data, suffix) +def calc_distances( + data: dict[str, Tensor], suffix: str = "", pad_value: float = 1.0, compile_nb_mode: int = -1 +) -> tuple[Tensor, Tensor]: + coord_i, coord_j = nbops.get_ij(data["coord"], data, suffix, compile_nb_mode=compile_nb_mode) if f"shifts{suffix}" in data: assert "cell" in data, "cell is required if shifts are provided" - nb_mode = nbops.get_nb_mode(data) + nb_mode = compile_nb_mode if compile_nb_mode >= 0 else nbops.get_nb_mode(data) if nb_mode == 2: shifts = torch.einsum("bnmd,bdh->bnmh", data[f"shifts{suffix}"], data["cell"]) else: @@ -50,6 +52,23 @@ def cosine_cutoff(d_ij: Tensor, rc: float) -> Tensor: return fc +def cosine_cutoff_tensor(d_ij: Tensor, rc: Tensor) -> Tensor: + """Cosine cutoff function with tensor cutoff (for CUDA graphs compatibility). + + Unlike cosine_cutoff(), this version accepts a tensor cutoff value, + which is necessary for torch.compile with CUDA graphs since control + flow based on tensor values breaks graph capture. + + Args: + d_ij: Distance tensor + rc: Cutoff radius as Tensor + + Returns: + Cutoff values with 0 for distances >= rc + """ + return torch.where(d_ij < rc, 0.5 * (torch.cos(d_ij * (torch.pi / rc)) + 1.0), torch.zeros_like(d_ij)) + + def exp_cutoff(d: Tensor, rc: Tensor) -> Tensor: fc = torch.exp(-1.0 / (1.0 - (d / rc).clamp(0, 1.0 - 1e-6).pow(2))) / 0.36787944117144233 return fc @@ -66,17 +85,18 @@ def nse( f_u: Tensor, data: dict[str, Tensor], epsilon: float = 1.0e-6, + compile_nb_mode: int = -1, ) -> Tensor: # Q and q_u and f_u must have last dimension size 1 or 2 - F_u = nbops.mol_sum(f_u, data) + F_u = nbops.mol_sum(f_u, data, compile_nb_mode=compile_nb_mode) if epsilon > 0: F_u = F_u + epsilon - Q_u = nbops.mol_sum(q_u, data) + Q_u = nbops.mol_sum(q_u, data, compile_nb_mode=compile_nb_mode) dQ = Q - Q_u # for loss data["_dQ"] = dQ - nb_mode = nbops.get_nb_mode(data) + nb_mode = compile_nb_mode if compile_nb_mode >= 0 else nbops.get_nb_mode(data) if nb_mode in (0, 2): F_u = F_u.unsqueeze(-2) dQ = dQ.unsqueeze(-2) diff --git a/aimnet/py.typed b/aimnet/py.typed index 59c0a91..92a3b04 100644 --- a/aimnet/py.typed +++ b/aimnet/py.typed @@ -1,3 +1,2 @@ # PEP 561 marker file # This file indicates that the package supports type checking - diff --git a/docs/api/calculators.md b/docs/api/calculators.md index ec7c0cd..c5efc1d 100644 --- a/docs/api/calculators.md +++ b/docs/api/calculators.md @@ -13,7 +13,7 @@ The core calculator for running AIMNet2 inference. ASE calculator interface for AIMNet2. !!! note - Requires the `ase` extra: `pip install aimnet[ase]` +Requires the `ase` extra: `pip install aimnet[ase]` ::: aimnet.calculators.aimnet2ase.AIMNet2ASE @@ -22,7 +22,7 @@ ASE calculator interface for AIMNet2. PySisyphus calculator interface for AIMNet2. !!! note - Requires the `pysis` extra: `pip install aimnet[pysis]` +Requires the `pysis` extra: `pip install aimnet[pysis]` ::: aimnet.calculators.aimnet2pysis.AIMNet2Pysis @@ -31,4 +31,3 @@ PySisyphus calculator interface for AIMNet2. Utilities for loading pre-trained models. ::: aimnet.calculators.model_registry - diff --git a/docs/api/config.md b/docs/api/config.md index 9134fc8..e551f20 100644 --- a/docs/api/config.md +++ b/docs/api/config.md @@ -5,4 +5,3 @@ Configuration and model building utilities. ## Build Module ::: aimnet.config.build_module - diff --git a/docs/api/data.md b/docs/api/data.md index b9338fa..0c1a4c5 100644 --- a/docs/api/data.md +++ b/docs/api/data.md @@ -9,4 +9,3 @@ Dataset handling and data loading utilities. ## SizeGroupedSampler ::: aimnet.data.sgdataset.SizeGroupedSampler - diff --git a/docs/api/index.md b/docs/api/index.md index aa59f2f..5e6e063 100644 --- a/docs/api/index.md +++ b/docs/api/index.md @@ -24,4 +24,3 @@ The main entry points for using AIMNet2: ```bash aimnet --help ``` - diff --git a/docs/api/modules.md b/docs/api/modules.md index 8852787..0671be4 100644 --- a/docs/api/modules.md +++ b/docs/api/modules.md @@ -13,4 +13,3 @@ Neural network modules and model components. ## Base Classes ::: aimnet.models.base - diff --git a/examples/ase_md_compiled.py b/examples/ase_md_compiled.py new file mode 100644 index 0000000..796249e --- /dev/null +++ b/examples/ase_md_compiled.py @@ -0,0 +1,109 @@ +"""Example: Molecular dynamics with torch.compile and CUDA graphs. + +This example demonstrates the speedup from using torch.compile with CUDA graphs +for molecular dynamics simulations. On a modern GPU, compile mode can provide +~5x speedup for small molecules (76s -> 15s for 10k MD steps on caffeine). + +Usage: + python ase_md_compiled.py # Normal mode + python ase_md_compiled.py --compile # Compile mode with CUDA graphs + +Requirements: + - CUDA GPU + - ASE (pip install aimnet[ase]) +""" + +import argparse +import os +from time import perf_counter + +import torch + + +def torch_show_device_info(): + print(f"Torch version: {torch.__version__}") + if torch.cuda.is_available(): + print(f"CUDA available, version {torch.version.cuda}, device: {torch.cuda.get_device_name()}") # type: ignore + else: + print("CUDA not available") + + +def run_md(atoms, steps=1000, timestep=0.5, temperature=300): + """Run molecular dynamics simulation.""" + from ase import units + from ase.md.langevin import Langevin + + # Setup Langevin dynamics + dyn = Langevin( + atoms, + timestep * units.fs, + temperature_K=temperature, + friction=0.01 / units.fs, + ) + + # Run dynamics + t0 = perf_counter() + dyn.run(steps) + t1 = perf_counter() + + return t1 - t0, dyn.nsteps + + +def main(): + parser = argparse.ArgumentParser(description="AIMNet2 MD with torch.compile") + parser.add_argument("--compile", action="store_true", help="Enable torch.compile with CUDA graphs") + parser.add_argument("--model", type=str, default="aimnet2", help="Model name") + parser.add_argument("--steps", type=int, default=1000, help="Number of MD steps") + parser.add_argument("--warmup", type=int, default=10, help="Warmup steps (for compilation)") + args = parser.parse_args() + + # Check CUDA availability for compile mode + if args.compile and not torch.cuda.is_available(): + print("Error: --compile requires CUDA") + return + + torch_show_device_info() + print() + + # Load molecule + import ase.io + + from aimnet.calculators import AIMNet2ASE + + xyzfile = os.path.join(os.path.dirname(__file__), "..", "tests", "data", "caffeine.xyz") + atoms = ase.io.read(xyzfile) + + print(f"Molecule: {len(atoms)} atoms") + print(f"Model: {args.model}") + print(f"Compile mode: {args.compile}") + print() + + # Create calculator + calc = AIMNet2ASE(args.model, compile_mode=args.compile) + atoms.calc = calc + + # Warmup (especially important for compile mode to build CUDA graphs) + print(f"Running {args.warmup} warmup steps...") + warmup_time, _ = run_md(atoms, steps=args.warmup) + print(f"Warmup completed in {warmup_time:.2f}s") + + if args.compile: + print("(First run includes torch.compile compilation time)") + print() + + # Main MD run + print(f"Running {args.steps} MD steps...") + elapsed, nsteps = run_md(atoms, steps=args.steps) + + print("\nResults:") + print(f" Total time: {elapsed:.2f}s") + print(f" Time per step: {elapsed / nsteps * 1000:.2f}ms") + print(f" Steps per second: {nsteps / elapsed:.1f}") + + # Get final energy + energy = atoms.get_potential_energy() + print(f" Final energy: {energy:.4f} eV") + + +if __name__ == "__main__": + main() diff --git a/tests/test_compile.py b/tests/test_compile.py new file mode 100644 index 0000000..9da972d --- /dev/null +++ b/tests/test_compile.py @@ -0,0 +1,212 @@ +"""Tests for torch.compile with CUDA graphs support.""" + +import os + +import numpy as np +import pytest +import torch + +from aimnet.calculators import AIMNet2Calculator + +file = os.path.join(os.path.dirname(__file__), "data", "caffeine.xyz") + + +def load_mol(filepath): + """Load molecule from XYZ file.""" + pytest.importorskip("ase", reason="ASE not installed. Install with: pip install aimnet[ase]") + import ase.io + + atoms = ase.io.read(filepath) + data = { + "coord": atoms.get_positions(), # type: ignore + "numbers": atoms.get_atomic_numbers(), # type: ignore + "charge": 0.0, + } + return data + + +@pytest.mark.gpu +@pytest.mark.ase +def test_compile_mode_requires_cuda(): + """Test that compile_mode raises error without CUDA.""" + if torch.cuda.is_available(): + pytest.skip("CUDA is available, skipping non-CUDA test") + with pytest.raises(ValueError, match="compile_mode requires CUDA"): + AIMNet2Calculator("aimnet2", compile_mode=True) + + +@pytest.mark.gpu +@pytest.mark.ase +def test_compile_mode_requires_model_name(): + """Test that compile_mode requires model name (str), not nn.Module.""" + if not torch.cuda.is_available(): + pytest.skip("CUDA not available") + # Load a regular model first + calc = AIMNet2Calculator("aimnet2") + # Try to use compile_mode with the model instance + with pytest.raises(ValueError, match="compile_mode requires model name"): + AIMNet2Calculator(calc.model, compile_mode=True) + + +@pytest.mark.gpu +@pytest.mark.ase +def test_compile_mode_basic(): + """Test compile mode produces results.""" + if not torch.cuda.is_available(): + pytest.skip("CUDA not available") + + calc = AIMNet2Calculator("aimnet2", compile_mode=True) + data = load_mol(file) + res = calc(data) + assert "energy" in res + assert "charges" in res + + +@pytest.mark.gpu +@pytest.mark.ase +def test_compile_mode_energy_consistency(): + """Test compile mode produces same energy as normal mode (within tolerance).""" + if not torch.cuda.is_available(): + pytest.skip("CUDA not available") + + data = load_mol(file) + + # Normal mode + calc_normal = AIMNet2Calculator("aimnet2") + res_normal = calc_normal(data) + + # Compile mode + calc_compile = AIMNet2Calculator("aimnet2", compile_mode=True) + res_compile = calc_compile(data) + + # Results should be very close (float32 tolerance) + np.testing.assert_allclose( + res_normal["energy"].cpu().numpy(), + res_compile["energy"].cpu().numpy(), + rtol=1e-4, + atol=1e-5, + ) + + +@pytest.mark.gpu +@pytest.mark.ase +def test_compile_mode_forces(): + """Test compile mode with forces calculation.""" + if not torch.cuda.is_available(): + pytest.skip("CUDA not available") + + calc = AIMNet2Calculator("aimnet2", compile_mode=True) + data = load_mol(file) + res = calc(data, forces=True) + assert "energy" in res + assert "forces" in res + assert res["forces"].shape[0] == 24 # caffeine has 24 atoms + + +@pytest.mark.gpu +@pytest.mark.ase +def test_compile_mode_forces_consistency(): + """Test compile mode produces same forces as normal mode.""" + if not torch.cuda.is_available(): + pytest.skip("CUDA not available") + + data = load_mol(file) + + # Normal mode + calc_normal = AIMNet2Calculator("aimnet2") + res_normal = calc_normal(data, forces=True) + + # Compile mode + calc_compile = AIMNet2Calculator("aimnet2", compile_mode=True) + res_compile = calc_compile(data, forces=True) + + # Forces should be close + np.testing.assert_allclose( + res_normal["forces"].cpu().numpy(), + res_compile["forces"].cpu().numpy(), + rtol=1e-3, + atol=1e-4, + ) + + +@pytest.mark.gpu +@pytest.mark.ase +def test_compile_mode_multiple_calls(): + """Test CUDA graph reuse across multiple calls.""" + if not torch.cuda.is_available(): + pytest.skip("CUDA not available") + + calc = AIMNet2Calculator("aimnet2", compile_mode=True) + data = load_mol(file) + + # First call (warm-up / compilation) + res1 = calc(data, forces=True) + + # Subsequent calls should use cached CUDA graph + res2 = calc(data, forces=True) + res3 = calc(data, forces=True) + + # All results should be identical + np.testing.assert_array_equal( + res1["energy"].cpu().numpy(), + res2["energy"].cpu().numpy(), + ) + np.testing.assert_array_equal( + res2["energy"].cpu().numpy(), + res3["energy"].cpu().numpy(), + ) + + +@pytest.mark.gpu +@pytest.mark.ase +@pytest.mark.parametrize("model", ["aimnet2", "aimnet2_b973c"]) +def test_compile_mode_different_models(model): + """Test compile mode with different models.""" + if not torch.cuda.is_available(): + pytest.skip("CUDA not available") + + calc = AIMNet2Calculator(model, compile_mode=True) + data = load_mol(file) + res = calc(data) + assert "energy" in res + assert "charges" in res + + +@pytest.mark.gpu +@pytest.mark.ase +def test_compile_mode_pbc_not_supported(): + """Test that PBC raises error in compile mode.""" + if not torch.cuda.is_available(): + pytest.skip("CUDA not available") + + calc = AIMNet2Calculator("aimnet2", compile_mode=True) + data = load_mol(file) + data["cell"] = np.eye(3) * 10.0 # Add cell for PBC + + with pytest.raises(NotImplementedError, match="PBC is not supported in compile mode"): + calc(data) + + +@pytest.mark.gpu +@pytest.mark.ase +def test_ase_calculator_compile_mode(): + """Test ASE calculator with compile mode.""" + if not torch.cuda.is_available(): + pytest.skip("CUDA not available") + + pytest.importorskip("ase", reason="ASE not installed") + from ase.io import read + + from aimnet.calculators import AIMNet2ASE + + atoms = read(file) + atoms.calc = AIMNet2ASE("aimnet2", compile_mode=True) + + e = atoms.get_potential_energy() + assert isinstance(e, float) + + f = atoms.get_forces() + assert f.shape == (24, 3) + + q = atoms.get_charges() + assert q.shape == (24,)