From f3d36f898508926ac59d9d03e9efbc65ef9d44e7 Mon Sep 17 00:00:00 2001 From: Rohit Goswami Date: Mon, 16 Feb 2026 21:02:50 +0100 Subject: [PATCH 1/5] feat(pet): add neighbor_atom_indices to systems_to_batch Return the neighbor atom indices in NEF format from systems_to_batch. This tensor maps each edge to its neighbor atom index, needed to scatter edge-level force gradients to per-atom forces in the compiled forward path. --- src/metatrain/pet/model.py | 1 + src/metatrain/pet/modules/structures.py | 5 +++++ 2 files changed, 6 insertions(+) diff --git a/src/metatrain/pet/model.py b/src/metatrain/pet/model.py index 757e778140..f9b7d716ac 100644 --- a/src/metatrain/pet/model.py +++ b/src/metatrain/pet/model.py @@ -416,6 +416,7 @@ def forward( reverse_neighbor_index, cutoff_factors, system_indices, + _neighbor_atom_indices, sample_labels, ) = systems_to_batch( systems, diff --git a/src/metatrain/pet/modules/structures.py b/src/metatrain/pet/modules/structures.py index 8438a7c67e..291556f9d8 100644 --- a/src/metatrain/pet/modules/structures.py +++ b/src/metatrain/pet/modules/structures.py @@ -154,6 +154,7 @@ def systems_to_batch( torch.Tensor, torch.Tensor, torch.Tensor, + torch.Tensor, Labels, ]: """ @@ -181,6 +182,8 @@ def systems_to_batch( - `reverse_neighbor_index`: The reversed neighbor list for each central atom - `cutoff_factors`: The cutoff function values for each edge - `system_indices`: The system index for each atom in the batch + - `neighbor_atom_indices`: The atom index of each neighbor in NEF format, + used to scatter edge-level gradients to per-atom forces in the compiled path - `sample_labels`: Labels indicating the system and atom indices for each atom """ @@ -296,6 +299,7 @@ def systems_to_batch( reversed_neighbor_list = compute_reversed_neighbor_list( nef_indices, corresponding_edges, nef_mask ) + neighbor_atom_indices = edge_array_to_nef(neighbors, nef_indices, nef_mask, 0.0) neighbors_index = edge_array_to_nef(neighbors, nef_indices).to(torch.int64) # Here, we compute the array that allows indexing into a flattened @@ -320,5 +324,6 @@ def systems_to_batch( reverse_neighbor_index, cutoff_factors, system_indices, + neighbor_atom_indices, sample_labels, ) From fda7130cc763df78b01575fc9b67dd0378be84c4 Mon Sep 17 00:00:00 2001 From: Rohit Goswami Date: Mon, 16 Feb 2026 21:03:17 +0100 Subject: [PATCH 2/5] feat(pet): add pure-tensor _forward_from_batch method Add a pure-tensor forward path that bypasses metatensor wrapping, returning Dict[str, Dict[str, Tensor]] (target -> block_key -> tensor). Always uses SDPA attention since forces will be computed via autograd.grad(create_graph=False) in the compiled graph, avoiding double backward entirely. --- src/metatrain/pet/model.py | 72 ++++++++++++++++++++++++++++++++++++++ 1 file changed, 72 insertions(+) diff --git a/src/metatrain/pet/model.py b/src/metatrain/pet/model.py index f9b7d716ac..9ecfac551b 100644 --- a/src/metatrain/pet/model.py +++ b/src/metatrain/pet/model.py @@ -571,6 +571,78 @@ def forward( return return_dict + def _forward_from_batch( + self, + element_indices_nodes: torch.Tensor, + element_indices_neighbors: torch.Tensor, + edge_vectors: torch.Tensor, + edge_distances: torch.Tensor, + padding_mask: torch.Tensor, + reverse_neighbor_index: torch.Tensor, + cutoff_factors: torch.Tensor, + ) -> Dict[str, Dict[str, torch.Tensor]]: + """ + Pure-tensor forward pass for FX compilation. + + Takes batch tensors and returns raw per-atom predictions as nested + dictionaries (target_name -> block_key -> tensor). 
Always uses SDPA + attention (no manual attention needed since forces are computed via + ``autograd.grad(create_graph=False)`` in the compiled graph, avoiding + double backward). + + :param element_indices_nodes: Atomic species of central atoms [n_atoms]. + :param element_indices_neighbors: Atomic species of neighbors + [n_atoms, max_neighbors]. + :param edge_vectors: Edge vectors [n_atoms, max_neighbors, 3]. + :param edge_distances: Edge distances [n_atoms, max_neighbors]. + :param padding_mask: Boolean mask for real neighbors + [n_atoms, max_neighbors]. + :param reverse_neighbor_index: Reversed neighbor index for message + passing [n_atoms, max_neighbors]. + :param cutoff_factors: Cutoff function values [n_atoms, max_neighbors]. + :return: Nested dict mapping target_name -> block_key -> per-atom + prediction tensor. + """ + featurizer_inputs: Dict[str, torch.Tensor] = dict( + element_indices_nodes=element_indices_nodes, + element_indices_neighbors=element_indices_neighbors, + edge_vectors=edge_vectors, + edge_distances=edge_distances, + reverse_neighbor_index=reverse_neighbor_index, + padding_mask=padding_mask, + cutoff_factors=cutoff_factors, + ) + + # Always use SDPA (no double backward in compiled path) + node_features_list, edge_features_list = self._calculate_features( + featurizer_inputs, use_manual_attention=False + ) + + node_ll_dict, edge_ll_dict = self._calculate_last_layer_features( + node_features_list, edge_features_list + ) + + outputs_all: Dict[str, ModelOutput] = { + name: ModelOutput(per_atom=True) for name in self.target_names + } + node_preds, edge_preds = self._calculate_atomic_predictions( + node_ll_dict, edge_ll_dict, padding_mask, cutoff_factors, outputs_all + ) + + # Sum across GNN layers for each target/block + results: Dict[str, Dict[str, torch.Tensor]] = {} + for target_name in self.target_names: + block_results: Dict[str, torch.Tensor] = {} + node_layers = node_preds[target_name] + edge_layers = edge_preds[target_name] + for j, key in enumerate(self.output_shapes[target_name]): + total = node_layers[0][j] + edge_layers[0][j] + for i in range(1, len(node_layers)): + total = total + node_layers[i][j] + edge_layers[i][j] + block_results[key] = total + results[target_name] = block_results + return results + def _calculate_features( self, inputs: Dict[str, torch.Tensor], use_manual_attention: bool ) -> Tuple[List[torch.Tensor], List[torch.Tensor]]: From 3c4e684294b82d851111fecd915de54fd2b73f07 Mon Sep 17 00:00:00 2001 From: Rohit Goswami Date: Mon, 16 Feb 2026 21:03:44 +0100 Subject: [PATCH 3/5] feat(pet): full-graph FX compilation for training Trace the entire PET forward pass (including force/stress computation via autograd.grad) into a single FX graph using make_fx, then compile with torch.compile(dynamic=True, fullgraph=True). This gives maximum kernel fusion, zero compiled/eager boundary crossings, and always uses SDPA since forces use create_graph=False (single backward, no double backward). 
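
In miniature, the recipe looks like this (a toy model standing in for PET;
the real implementation in modules/compile.py additionally swaps parameters
via NamedMemberAccessor, allows non-fake inputs, and disables duck shaping):

    import torch
    from torch.fx.experimental.proxy_tensor import make_fx

    net = torch.nn.Sequential(
        torch.nn.Linear(3, 16), torch.nn.Tanh(), torch.nn.Linear(16, 1)
    )
    names = [n for n, _ in net.named_parameters()]

    def energy_and_forces(positions, *params):
        # run the model functionally so parameters are graph inputs
        energy = torch.func.functional_call(
            net, dict(zip(names, params)), (positions,)
        ).sum()
        # single backward, traced into the same graph as the forward
        (grad,) = torch.autograd.grad(energy, positions, create_graph=False)
        return energy, -grad

    positions = torch.randn(5, 3, requires_grad=True)
    traced = make_fx(energy_and_forces, tracing_mode="symbolic")(
        positions, *net.parameters()
    )
    compiled = torch.compile(traced, dynamic=True, fullgraph=True)
    energy, forces = compiled(positions, *net.parameters())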
Key components: - modules/compile.py: _PETBatchForward wrapper, _make_pet_compiled_forward traceable function with NamedMemberAccessor param swapping, and compile_pet_model orchestrator - trainer.py: _wrap_compiled_output converts compiled outputs to Dict[str, TensorMap], training loop branches between compiled and eager paths - DecomposedSiLU (x * sigmoid(x)) replaces nn.SiLU before tracing since inductor can't differentiate silu_backward nodes --- pyproject.toml | 7 +- src/metatrain/pet/checkpoints.py | 9 + src/metatrain/pet/documentation.py | 11 + src/metatrain/pet/model.py | 9 +- src/metatrain/pet/modules/compile.py | 293 +++++++++++++++++++++++++ src/metatrain/pet/modules/utilities.py | 31 +++ src/metatrain/pet/trainer.py | 278 +++++++++++++++++++++-- 7 files changed, 613 insertions(+), 25 deletions(-) create mode 100644 src/metatrain/pet/modules/compile.py diff --git a/pyproject.toml b/pyproject.toml index 4025b408c3..18ea6df8cf 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -163,6 +163,7 @@ filterwarnings = [ "ignore:`torch.jit.script` is deprecated. Please switch to `torch.compile` or `torch.export`:DeprecationWarning", "ignore:`torch.jit.save` is deprecated. Please switch to `torch.export`:DeprecationWarning", "ignore:`torch.jit.load` is deprecated. Please switch to `torch.export`:DeprecationWarning", + "ignore:`torch.jit.script_method` is deprecated. Please switch to `torch.compile` or `torch.export`:DeprecationWarning", "ignore:`torch.jit.trace_method` is deprecated. Please switch to `torch.compile` or `torch.export`:DeprecationWarning", "ignore:`torch.jit.trace` is deprecated. Please switch to `torch.compile` or `torch.export`:DeprecationWarning", # PyTorch does not want these, but mypy requires them @@ -177,7 +178,11 @@ filterwarnings = [ # Multi-threaded tests clash with multi-process data-loading "ignore:This process \\(pid=\\d+\\) is multi-threaded, use of fork\\(\\) may lead to deadlocks in the child.:DeprecationWarning", # MACE warning with newer versions of pytorch (because they use e3nn==0.4.4) - "ignore:Environment variable TORCH_FORCE_NO_WEIGHTS_ONLY_LOAD detected, since the`weights_only` argument was not explicitly passed:UserWarning" + "ignore:Environment variable TORCH_FORCE_NO_WEIGHTS_ONLY_LOAD detected, since the`weights_only` argument was not explicitly passed:UserWarning", + # compiled_autograd + non-leaf tensors: dynamo incorrectly converts this C++ + # warning to an error when tracing backward graphs through non-leaf tensors + # (e.g. edge_distances computed from edge_vectors via sqrt) + "ignore:The .grad attribute of a Tensor that is not a leaf Tensor is being accessed:UserWarning" ] addopts = ["-p", "mtt_plugin"] pythonpath = "src/metatrain/utils/testing" diff --git a/src/metatrain/pet/checkpoints.py b/src/metatrain/pet/checkpoints.py index 020b1e6590..f09c9afbb3 100644 --- a/src/metatrain/pet/checkpoints.py +++ b/src/metatrain/pet/checkpoints.py @@ -420,3 +420,12 @@ def trainer_update_v11_v12(checkpoint: dict) -> None: :param checkpoint: The checkpoint to update. """ checkpoint["train_hypers"]["batch_atom_bounds"] = [None, None] + + +def trainer_update_v12_v13(checkpoint: dict) -> None: + """ + Update trainer checkpoint from version 12 to version 13. + + :param checkpoint: The checkpoint to update. 
+ """ + checkpoint["train_hypers"]["compile"] = False diff --git a/src/metatrain/pet/documentation.py b/src/metatrain/pet/documentation.py index 8eaec9c03e..629e02c0ce 100644 --- a/src/metatrain/pet/documentation.py +++ b/src/metatrain/pet/documentation.py @@ -257,3 +257,14 @@ class TrainerHypers(TypedDict): See :ref:`label_fine_tuning_concept` for more details. """ + compile: bool = False + """Whether to use full-graph FX compilation during training. + + When enabled, the entire PET model (including force/stress computation via + ``autograd.grad``) is traced into a single FX graph using ``make_fx`` and + then compiled with ``torch.compile(dynamic=True, fullgraph=True)``. This + gives maximum kernel fusion, zero compiled/eager boundary crossings, and + always uses ``scaled_dot_product_attention`` (SDPA). Expect a one-time + compilation cost at the start of training, followed by speedups on every + subsequent step. + """ diff --git a/src/metatrain/pet/model.py b/src/metatrain/pet/model.py index 9ecfac551b..c5cb7e5b5a 100644 --- a/src/metatrain/pet/model.py +++ b/src/metatrain/pet/model.py @@ -1413,7 +1413,12 @@ def upgrade_checkpoint(cls, checkpoint: Dict) -> Dict: return checkpoint def get_checkpoint(self) -> Dict: - model_state_dict = self.state_dict() + # Get state dict, handling compiled modules by removing _orig_mod prefix + state_dict = { + k.replace("._orig_mod", ""): v + for k, v in self.state_dict().items() + } + model_state_dict = dict(state_dict) model_state_dict["finetune_config"] = self.finetune_config checkpoint = { "architecture_name": "pet", @@ -1426,7 +1431,7 @@ def get_checkpoint(self) -> Dict: "epoch": None, "best_epoch": None, "model_state_dict": model_state_dict, - "best_model_state_dict": self.state_dict(), + "best_model_state_dict": state_dict, } return checkpoint diff --git a/src/metatrain/pet/modules/compile.py b/src/metatrain/pet/modules/compile.py new file mode 100644 index 0000000000..56287e1c44 --- /dev/null +++ b/src/metatrain/pet/modules/compile.py @@ -0,0 +1,293 @@ +"""Full-graph FX compilation for PET. + +Traces the entire PET forward pass (including force/stress computation via +``autograd.grad``) into a single FX graph, then compiles it with +``torch.compile(dynamic=True, fullgraph=True)``. This gives maximum kernel +fusion, zero compiled/eager boundary crossings, and always uses SDPA +(``scaled_dot_product_attention``) since forces use +``create_graph=False`` (no double backward). +""" + +import logging +from typing import Dict, List, Optional, Tuple + +import torch +from torch.nn.utils._named_member_accessor import NamedMemberAccessor + +from .utilities import replace_silu_modules + + +class _PETBatchForward(torch.nn.Module): + """Thin wrapper whose ``forward()`` delegates to ``pet._forward_from_batch``. + + PET is registered as a submodule so its parameters/buffers are visible + to ``functional_call`` / ``NamedMemberAccessor``. 
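+
+    A minimal sketch of how this wrapper is used further down in this
+    module::
+
+        batch_model = _PETBatchForward(pet)
+        params = dict(batch_model.named_parameters())  # keys start with "pet."
+        buffers = dict(batch_model.named_buffers())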
+ """ + + def __init__(self, pet: torch.nn.Module) -> None: + super().__init__() + self.pet = pet + + def forward( + self, + element_indices_nodes: torch.Tensor, + element_indices_neighbors: torch.Tensor, + edge_vectors: torch.Tensor, + edge_distances: torch.Tensor, + padding_mask: torch.Tensor, + reverse_neighbor_index: torch.Tensor, + cutoff_factors: torch.Tensor, + ) -> Dict[str, Dict[str, torch.Tensor]]: + return self.pet._forward_from_batch( + element_indices_nodes, + element_indices_neighbors, + edge_vectors, + edge_distances, + padding_mask, + reverse_neighbor_index, + cutoff_factors, + ) + + +def _make_pet_compiled_forward( + batch_model: _PETBatchForward, + param_names: List[str], + buffer_names: List[str], + target_names: List[str], + output_shapes: Dict[str, Dict[str, List[int]]], + compute_forces: bool, + compute_stress: bool, +): + """Build the traceable forward function for ``make_fx``. + + The returned function accepts all batch tensors and the model's + parameters/buffers as positional arguments (required by + ``make_fx`` with ``functional_call``). It returns + ``(per_structure_preds, forces, stress, raw_predictions)``. + """ + n_params = len(param_names) + accessor = NamedMemberAccessor(batch_model) + + # Identify which target is the energy target (quantity == "energy") + # For force/stress we need to aggregate per-atom energy to per-structure. + energy_target_name: Optional[str] = None + energy_block_key: Optional[str] = None + pet = batch_model.pet + for tname in target_names: + if hasattr(pet, "outputs") and tname in pet.outputs: + if pet.outputs[tname].quantity == "energy": + energy_target_name = tname + # First block key for this target + energy_block_key = next(iter(output_shapes[tname])) + break + + if (compute_forces or compute_stress) and energy_target_name is None: + raise ValueError( + "Force/stress compilation requested but no energy target found." 
+ ) + + def forward_fn( + edge_vectors, + element_indices_nodes, + element_indices_neighbors, + padding_mask, + reverse_neighbor_index, + cutoff_factors, + system_indices, + neighbor_atom_indices, + n_structures, + *params_and_buffers, + ): + # Swap in the provided params/buffers via NamedMemberAccessor + params_buffers = {} + for i, name in enumerate(param_names): + params_buffers[name] = params_and_buffers[i] + for i, name in enumerate(buffer_names): + params_buffers[name] = params_and_buffers[n_params + i] + + orig_values, _ = accessor.swap_tensors_dict( + params_buffers, allow_missing=True + ) + + # Compute edge_distances inside compiled graph (differentiable) + edge_distances = torch.sqrt((edge_vectors**2).sum(-1) + 1e-15) + + raw_predictions = batch_model( + element_indices_nodes, + element_indices_neighbors, + edge_vectors, + edge_distances, + padding_mask, + reverse_neighbor_index, + cutoff_factors, + ) + + # Restore original params/buffers + accessor.swap_tensors_dict(orig_values, allow_missing=True) + + # Aggregate per-atom predictions to per-structure for the energy target + n_atoms = edge_vectors.shape[0] + # +1 for padding structure index (scatter needs valid indices) + n_struct = n_structures + 1 + + energy: Optional[torch.Tensor] = None + forces: Optional[torch.Tensor] = None + stress: Optional[torch.Tensor] = None + + if energy_target_name is not None and energy_block_key is not None: + per_atom_energy = raw_predictions[energy_target_name][energy_block_key] + energy = torch.zeros( + n_struct, dtype=edge_vectors.dtype, device=edge_vectors.device + ) + energy.scatter_add_( + 0, system_indices, per_atom_energy.squeeze(-1) + ) + + if (compute_forces or compute_stress) and energy is not None: + (dE_dR,) = torch.autograd.grad( + energy[:n_structures].sum(), + edge_vectors, + create_graph=False, + ) + dE_dR = dE_dR * padding_mask[:, :, None].float() + + if compute_forces: + # d(E)/d(pos[i]): + # as center: -sum_j dE_dR[i, j] + # as neighbor: +sum_{(k,j): neighbor_atom=i} dE_dR[k, j] + grad_as_center = -dE_dR.sum(dim=1) # [n_atoms, 3] + flat_dE = dE_dR.reshape(-1, 3) + flat_idx = neighbor_atom_indices.reshape(-1, 1).expand(-1, 3).long() + grad_as_neighbor = torch.zeros( + n_atoms, 3, dtype=edge_vectors.dtype, device=edge_vectors.device + ) + grad_as_neighbor.scatter_add_(0, flat_idx, flat_dE) + forces = grad_as_center + grad_as_neighbor + + if compute_stress: + # Virial: sigma = (1/V) sum r otimes (dE/dr) + virial_per_atom = torch.einsum("ema,emb->eab", edge_vectors, dE_dR) + stress_buf = torch.zeros( + n_struct, 3, 3, + dtype=edge_vectors.dtype, device=edge_vectors.device, + ) + stress_buf.scatter_add_( + 0, + system_indices[:, None, None].expand(-1, 3, 3), + virial_per_atom, + ) + stress = stress_buf[:n_structures] + + if energy is not None: + energy = energy[:n_structures] + + return energy, forces, stress, raw_predictions + + return forward_fn + + +def compile_pet_model( + model: torch.nn.Module, + train_dataloader, + compute_forces: bool, + compute_stress: bool, +) -> Tuple[torch.nn.Module, List[str], List[str]]: + """Trace and compile the PET model as a single FX graph. + + :param model: The PET model instance. + :param train_dataloader: A dataloader to get a sample batch for tracing. + :param compute_forces: Whether force computation is included. + :param compute_stress: Whether stress computation is included. + :return: Tuple of (compiled_module, param_names, buffer_names). 
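+
+    A rough usage sketch, mirroring the compiled training path in
+    ``trainer.py`` (batch tensors as returned by ``systems_to_batch``)::
+
+        compiled_fn, _, _ = compile_pet_model(
+            model, train_dataloader, compute_forces=True, compute_stress=False
+        )
+        edge_vectors.requires_grad_(True)  # needed for the traced force path
+        energy, forces, stress, raw_preds = compiled_fn(
+            edge_vectors, element_indices_nodes, element_indices_neighbors,
+            padding_mask, reverse_neighbor_index, cutoff_factors,
+            system_indices, neighbor_atom_indices, n_structures,
+            *model.parameters(), *model.buffers(),
+        )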
+ """ + from torch.fx.experimental.proxy_tensor import make_fx + + from metatrain.utils.data import unpack_batch + from metatrain.utils.transfer import batch_to + + from ..modules.structures import systems_to_batch + + batch_model = _PETBatchForward(model) + replace_silu_modules(batch_model) + + params = dict(batch_model.named_parameters()) + buffers = dict(batch_model.named_buffers()) + param_names = list(params.keys()) + buffer_names = list(buffers.keys()) + + forward_fn = _make_pet_compiled_forward( + batch_model, + param_names, + buffer_names, + model.target_names, + model.output_shapes, + compute_forces, + compute_stress, + ) + + # Get a sample batch for tracing + batch = next(iter(train_dataloader)) + systems, _targets, _extra_data = unpack_batch(batch) + device = next(model.parameters()).device + dtype = next(model.parameters()).dtype + systems, _, _ = batch_to(systems, {}, {}, dtype=dtype, device=device) + + ( + element_indices_nodes, + element_indices_neighbors, + edge_vectors, + edge_distances, + padding_mask, + reverse_neighbor_index, + cutoff_factors, + system_indices, + neighbor_atom_indices, + _sample_labels, + ) = systems_to_batch( + systems, + model.requested_nl, + model.atomic_types, + model.species_to_species_index, + model.cutoff_function, + model.cutoff_width, + model.num_neighbors_adaptive, + ) + + n_structures = int(system_indices.max().item()) + 1 + + # edge_vectors needs grad for force tracing + tracing_edge_vectors = edge_vectors.clone().requires_grad_(True) + + logging.info( + "Tracing PET model with make_fx (symbolic tracing)..." + ) + + old_duck = torch.fx.experimental._config.use_duck_shape + torch.fx.experimental._config.use_duck_shape = False + try: + fx_graph = make_fx( + forward_fn, + tracing_mode="symbolic", + _allow_non_fake_inputs=True, + )( + tracing_edge_vectors, + element_indices_nodes, + element_indices_neighbors, + padding_mask, + reverse_neighbor_index, + cutoff_factors, + system_indices, + neighbor_atom_indices, + n_structures, + *list(params.values()), + *list(buffers.values()), + ) + finally: + torch.fx.experimental._config.use_duck_shape = old_duck + + logging.info("Compiling traced FX graph with torch.compile...") + compiled = torch.compile( + fx_graph, dynamic=True, fullgraph=True + ) + + return compiled, param_names, buffer_names diff --git a/src/metatrain/pet/modules/utilities.py b/src/metatrain/pet/modules/utilities.py index 5795c5aed8..2414f2f1ca 100644 --- a/src/metatrain/pet/modules/utilities.py +++ b/src/metatrain/pet/modules/utilities.py @@ -52,6 +52,37 @@ def cutoff_func_cosine( return f +class DecomposedSiLU(torch.nn.Module): + """SiLU activation implemented as ``x * sigmoid(x)``. + + Unlike ``torch.nn.SiLU``, this decomposes into primitive ops so that + ``make_fx`` produces a backward graph without ``silu_backward`` nodes. + This is needed for ``torch.compile(inductor)`` to differentiate through + the inlined backward when using the FX compilation path for force training. + """ + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return x * torch.sigmoid(x) + + +def replace_silu_modules(module: torch.nn.Module) -> None: + """Replace all ``torch.nn.SiLU`` instances with :class:`DecomposedSiLU`. + + Recurses through the module tree, including inside ``nn.Sequential``. 
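+
+    For example, with a small throwaway module::
+
+        mlp = torch.nn.Sequential(torch.nn.Linear(8, 8), torch.nn.SiLU())
+        replace_silu_modules(mlp)
+        assert isinstance(mlp[1], DecomposedSiLU)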
+ """ + for name, child in module.named_children(): + if isinstance(child, torch.nn.SiLU): + setattr(module, name, DecomposedSiLU()) + elif isinstance(child, torch.nn.Sequential): + for i, layer in enumerate(child): + if isinstance(layer, torch.nn.SiLU): + child[i] = DecomposedSiLU() + else: + replace_silu_modules(layer) + else: + replace_silu_modules(child) + + class DummyModule(torch.nn.Module): """Dummy torch module to make torchscript happy. This model should never be run""" diff --git a/src/metatrain/pet/trainer.py b/src/metatrain/pet/trainer.py index 8194adc7de..e7264601ca 100644 --- a/src/metatrain/pet/trainer.py +++ b/src/metatrain/pet/trainer.py @@ -37,10 +37,145 @@ from metatrain.utils.scaler import get_remove_scale_transform from metatrain.utils.transfer import batch_to +from metatensor.torch import Labels, TensorBlock, TensorMap +from metatomic.torch import System + from . import checkpoints from .documentation import TrainerHypers from .model import PET from .modules.finetuning import apply_finetuning_strategy +from .modules.structures import systems_to_batch + + +def _wrap_compiled_output( + energy: torch.Tensor, + forces: torch.Tensor, + stress: torch.Tensor, + raw_predictions: Dict[str, Dict[str, torch.Tensor]], + model: PET, + systems: List[System], + sample_labels: Labels, + system_indices: torch.Tensor, + train_targets: Dict, +) -> Dict[str, TensorMap]: + """Convert compiled function outputs to Dict[str, TensorMap]. + + Produces the same format as ``evaluate_model`` so the loss function + and metric accumulators work unchanged. + """ + from metatrain.utils.sum_over_atoms import sum_over_atoms + + device = system_indices.device + predictions: Dict[str, TensorMap] = {} + + # Identify the energy target + energy_target_name = None + for tname in model.target_names: + if tname in model.outputs and model.outputs[tname].quantity == "energy": + energy_target_name = tname + break + + # Build energy TensorMap (per-structure) with optional gradient blocks + if energy_target_name is not None and energy is not None: + n_structures = energy.shape[0] + energy_block = TensorBlock( + values=energy.unsqueeze(-1), + samples=Labels( + "system", + torch.arange( + n_structures, device=device, dtype=torch.int32 + ).unsqueeze(-1), + assume_unique=True, + ), + components=[], + properties=Labels( + "energy", torch.tensor([[0]], device=device) + ), + ) + + if forces is not None: + # Position gradient block: samples are ["sample", "atom"] + # matching evaluate_model's _position_gradients_to_block format + grad_samples = Labels( + names=["sample", "atom"], + values=sample_labels.values.to(torch.int32), + assume_unique=True, + ).to(device) + xyz_labels = Labels( + "xyz", torch.tensor([[0], [1], [2]], device=device) + ) + forces_block = TensorBlock( + values=forces.unsqueeze(-1), + samples=grad_samples, + components=[xyz_labels], + properties=Labels( + "energy", torch.tensor([[0]], device=device) + ), + ) + energy_block.add_gradient("positions", forces_block) + + if stress is not None: + stress_samples = Labels( + "sample", + torch.arange( + n_structures, device=device, dtype=torch.int32 + ).unsqueeze(-1), + assume_unique=True, + ) + xyz1 = Labels( + "xyz_1", torch.tensor([[0], [1], [2]], device=device) + ) + xyz2 = Labels( + "xyz_2", torch.tensor([[0], [1], [2]], device=device) + ) + stress_block = TensorBlock( + values=stress.unsqueeze(-1), + samples=stress_samples, + components=[xyz1, xyz2], + properties=Labels( + "energy", torch.tensor([[0]], device=device) + ), + ) + 
energy_block.add_gradient("strain", stress_block) + + predictions[energy_target_name] = TensorMap( + keys=model.single_label, + blocks=[energy_block], + ) + + # Non-energy targets: wrap per-atom raw predictions into TensorMaps + for target_name in model.target_names: + if target_name == energy_target_name: + continue + if target_name not in raw_predictions: + continue + if target_name not in train_targets: + continue + + target_preds = raw_predictions[target_name] + blocks = [] + for key, shape, components, properties in zip( + model.output_shapes[target_name].keys(), + model.output_shapes[target_name].values(), + model.component_labels[target_name], + model.property_labels[target_name], + strict=True, + ): + values = target_preds[key].reshape([-1] + shape) + block = TensorBlock( + values=values, + samples=sample_labels, + components=components, + properties=properties, + ) + blocks.append(block) + + tmap = TensorMap(keys=model.key_labels[target_name], blocks=blocks) + if not train_targets[target_name].per_atom: + tmap = sum_over_atoms(tmap) + predictions[target_name] = tmap + + return predictions def get_scheduler( @@ -77,7 +212,7 @@ def lr_lambda(current_step: int) -> float: class Trainer(TrainerInterface[TrainerHypers]): - __checkpoint_version__ = 12 + __checkpoint_version__ = 13 def __init__(self, hypers: TrainerHypers) -> None: super().__init__(hypers) @@ -160,6 +295,26 @@ def train( additive_model.to(dtype=torch.float64) model.scaler.to(dtype=torch.float64) + # torch.compile: full-graph FX compilation of the entire model + # (including force/stress computation via autograd.grad). + compile_enabled = self.hypers.get("compile", False) + has_gradients = any( + len(target_info.gradients) > 0 + for target_info in model.dataset_info.targets.values() + ) + has_strain_gradients = any( + "strain" in target_info.gradients + for target_info in model.dataset_info.targets.values() + ) + if compile_enabled: + torch._dynamo.reset() + if is_distributed: + logging.warning( + "torch.compile with DDP is not yet supported. " + "Disabling compilation for distributed training." + ) + compile_enabled = False + logging.info("Calculating composition weights") model.additive_models[0].train_model( # this is the composition model train_datasets, @@ -357,6 +512,23 @@ def train( # Log the initial learning rate: logging.info(f"Base learning rate: {self.hypers['learning_rate']}") + # Full-graph FX compilation (after dataloaders are ready for tracing). 
+ compiled_fn = None + if compile_enabled: + from .modules.compile import compile_pet_model + + compiled_fn, _compiled_param_names, _compiled_buffer_names = ( + compile_pet_model( + model, + train_dataloader, + has_gradients, + has_strain_gradients, + ) + ) + logging.info( + "FX compilation complete (will optimize on first call)" + ) + start_epoch = 0 if self.epoch is None else self.epoch + 1 # Train the model: @@ -389,27 +561,85 @@ def train( systems, targets, extra_data = batch_to( systems, targets, extra_data, dtype=dtype, device=device ) - predictions = evaluate_model( - model, - systems, - {key: train_targets[key] for key in targets.keys()}, - is_training=True, - ) - # average by the number of atoms - predictions = average_by_num_atoms( - predictions, systems, per_structure_targets - ) - targets = average_by_num_atoms(targets, systems, per_structure_targets) - train_loss_batch = loss_fn(predictions, targets, extra_data) + if compile_enabled and compiled_fn is not None: + # FX-compiled path: call systems_to_batch directly, + # run the compiled function, and wrap outputs. + ( + c_element_indices_nodes, + c_element_indices_neighbors, + c_edge_vectors, + _c_edge_distances, + c_padding_mask, + c_reverse_neighbor_index, + c_cutoff_factors, + c_system_indices, + c_neighbor_atom_indices, + c_sample_labels, + ) = systems_to_batch( + systems, + model.requested_nl, + model.atomic_types, + model.species_to_species_index, + model.cutoff_function, + model.cutoff_width, + model.num_neighbors_adaptive, + ) + if has_gradients: + c_edge_vectors = c_edge_vectors.requires_grad_(True) + n_structures = len(systems) + energy, forces, stress, raw_preds = compiled_fn( + c_edge_vectors, + c_element_indices_nodes, + c_element_indices_neighbors, + c_padding_mask, + c_reverse_neighbor_index, + c_cutoff_factors, + c_system_indices, + c_neighbor_atom_indices, + n_structures, + *list(model.parameters()), + *list(model.buffers()), + ) + predictions = _wrap_compiled_output( + energy, + forces, + stress, + raw_preds, + model, + systems, + c_sample_labels, + c_system_indices, + train_targets, + ) + predictions = average_by_num_atoms( + predictions, systems, per_structure_targets + ) + targets = average_by_num_atoms( + targets, systems, per_structure_targets + ) + train_loss_batch = loss_fn(predictions, targets, extra_data) + train_loss_batch.backward() + else: + predictions = evaluate_model( + model, + systems, + {key: train_targets[key] for key in targets.keys()}, + is_training=True, + ) + predictions = average_by_num_atoms( + predictions, systems, per_structure_targets + ) + targets = average_by_num_atoms( + targets, systems, per_structure_targets + ) + train_loss_batch = loss_fn(predictions, targets, extra_data) - if is_distributed: - # make sure all parameters contribute to the gradient calculation - # to make torch DDP happy - for param in model.parameters(): - train_loss_batch += 0.0 * param.sum() + if is_distributed: + for param in model.parameters(): + train_loss_batch += 0.0 * param.sum() - train_loss_batch.backward() + train_loss_batch.backward() torch.nn.utils.clip_grad_norm_( model.parameters(), self.hypers["grad_clip_norm"] ) @@ -538,9 +768,13 @@ def train( ) if val_metric < self.best_metric: self.best_metric = val_metric - self.best_model_state_dict = copy.deepcopy( - (model.module if is_distributed else model).state_dict() - ) + raw_state_dict = ( + model.module if is_distributed else model + ).state_dict() + self.best_model_state_dict = { + k.replace("._orig_mod", ""): v.clone() + for k, v in 
raw_state_dict.items() + } self.best_epoch = epoch self.best_optimizer_state_dict = copy.deepcopy(optimizer.state_dict()) From 133bedc335b1c2fcd960f92f3040014e868a2e42 Mon Sep 17 00:00:00 2001 From: Rohit Goswami Date: Mon, 16 Feb 2026 21:04:06 +0100 Subject: [PATCH 4/5] test(pet): add tests for FX compilation and _forward_from_batch - test_forward_from_batch: verifies numerical equivalence between _forward_from_batch and the standard forward path - TestTrainingCompile: exercises the full FX compilation path during training (energy + forces), including restart and finetune scenarios - Existing CartesianTransformer compile unit tests retained --- .../checkpoints/model-v11_trainer-v13.ckpt.gz | Bin 0 -> 16997 bytes src/metatrain/pet/tests/test_compile.py | 226 ++++++++++++++++++ 2 files changed, 226 insertions(+) create mode 100644 src/metatrain/pet/tests/checkpoints/model-v11_trainer-v13.ckpt.gz create mode 100644 src/metatrain/pet/tests/test_compile.py diff --git a/src/metatrain/pet/tests/checkpoints/model-v11_trainer-v13.ckpt.gz b/src/metatrain/pet/tests/checkpoints/model-v11_trainer-v13.ckpt.gz new file mode 100644 index 0000000000000000000000000000000000000000..c5c58fb9f2a6b0c3bb8c43f860ac6ed8c5779318 GIT binary patch literal 16997 zcmZ8{1yoe+8tx3;-HmifgVG>KE7INF77a6$w4?$eDN4ggH$zGa(jg(;Aw9(04gY`c zIrpvwJHGwK^FD8Ud$VAS!NU54k?e$mbn~#abLI0A6tsBr%F4;z?iFw?{Lse5^9^V( z6Vr!mg7Pd`laOBFKuN;NyOu>$QSPHXCk9cfW~C~{==%OU2K23ItrtYPT12|hk?81L z(fIPZit=(`VkM>TxtutQgEKQ#!jj${KV!46kgduA8Z*0>EgLQ8KQq%!>?0aTNLH4$ zbyjSV+g%fp1AHgz&vNpHo+wAy=WhReIreN^0Y`Im<9 zQ%EbCx^#B!R;ozrrv{qAlL=-SS4H(A!zQcik>GPy{-Y~3^=UMN53c6T)3k?&B(B!( zIVq@L35VHQE8!!bvFQXdHMa#~G=IwUw7m55!`-*wdG4D7OSv$4>mul7y0N6*pu`}H zR48e09<&5ct6jKcO%W* z>S)#Adq2Yt(w7EyhF4=awTyQMhEk|(bvC{`J)YSzY(0FSTrw)t#NBH6s3w-@Wvw&1 zNm-zm(D^HO-y9Ly?6QS6?>);vhpCm;loZM)_2$-)4!UyBdAe*Kn5ubgPr@z@)yhwP z{=T(=#Ip3En#_8cwQeBsBP0s~GpUrEwl3!fZ?)w(URwP99q^=ZB^vWm$~$)|`G@^d z1)J$7SF`;o=R=p)sf<($HFkd$29tPyQB9Ti@qepM8w`ZP45Wggbz#qcn60Mw`@6z# zzBeBJ2q6h*$b~=TV(ZR&uqZ-k=$Mk((ehE$=YnY8u2zujOT=q&u8R%!g!#k1*4xmh zm6}yMe<4aO(7H4pUBX3_Y5N=h8_u+{zQ8{;dQsT}ZPV(%!R4DX>slC{k=`RjwdiE_ zt*P-Rnz$`r?7YZ*M0cCW>14)UMNd!SGJnq-O+pb@et1~v#y#y4LcYa!v`XeXQ;VA$ z+kazT=UrLA8Shp{M8FIAKJ_Gk4c+zXwk2pF0^(Wy>?GCqRg3%bX)48s&?$l&f3wXt zE}PjR@XXtKaZk*a)RYet^73-07c=k6^Z&jobZSk!YysC*u(qY_rQo_H`vkD1)0OmQ zW4DAKW}URQABey5&~%V|t^6t#N32SEn(yyNGBlj|HMUmJoM;BW%p_LLDuLKrP8wMk znjb=YW=G|=2GKkdxOx<}p^=reC3>S_?vH)gPGqWnv* z7jH@VaSP+-O?m9={y-&PQ-7|bRFZR6_RPJRR_Xe~*9Rw_6Eo_Yrw9#vp9o9~bq9Zk zAFn+GE2rcod--@AuJHELs4Oem`pAFW2&C@qL}w2yzy!qWstbE@Bl2YfIr~-+CC#le zt#i9m%rSe6mYBBR#qGUYji*p%?A}8?LZ$6?;l8Cjn6pGC`+tlBEV(5MM{pvf#;2BC zQ7v}AOfzeJ$RqejZ|oSB*U!a0l-WyekQ?E$)X6=akwcB?(!b|LJ5%~{*VCZ+heg&3 z;ce8x`%vlWNE|O)-)E2M#1|?@e=B&U@DMDdtej&0ZclahLXWSrLVs$VzZJ7Z=?ZEoN6b*4*DDXabD{JejYlWKM+UnOg#y{^*1+D06=9N0!W_q$ zuF#B7;DU_gLjPadr5)S9#j$~T$7(%J($y~%@67` z=&r@s3MEA0&jzdu93(}GLk6fxGzLjhqaU>s7 zbK1JUSM2TiiH|cx8n%9HL4g)dfibOQw3|7TnMch1kQ-O^H(MgC${UkypO|%jt&NvJ z?Kh?6!LT)BeKFcdgAEJocf|*PEE!i>@xW)y<{j{v=~iae-q z|3)fy^~unX`Qj?~KdFxaHEjS1!T znXcC67U)WYqDiL7W?aG2E>yRt%N?)Zm1KL>=yqT|r|UJ?OP%5>z-vGS+ru2}jl%!) 
zX#M59PcIx-Td1?OH?e`VU0ohE>Ug@WTyX(DAdlbydznx)NjLefz3D~ham@|xW7)Up z^!iaE?YtANjG=+jG4sjuyh#}*#-=@Hs)w>TTy4zV#?#`V9>A z=iyu;6JiYHFA^iIuiC01u(+ZP%e>*mx;5=RM?Sb3^bmOoVmS=NZnM%0m(`}*Y<`hCL9d>HV_2pB0qG_f$Z(--2K>0eG zqZ?sROjsx8`=?X#x24KEjUNiz7;i~X^Rsd~&zbUNIFDpS<;vdWOkh>h92;MkQI>sq zplPe1KS~zeXKaNsxk6F?aaLO#_bZ`L%J+HtPiK44+*WCL+iMLiK5Iyl zxyFak%fVQ%yq3{vxo{N1-+t_yQKd{7z%wv3$nA*bv_s*-$RGUR+gaCOc|>;=Cs<8PTJF!Rd^eM2$wbUgxDx(A=W zLXAEp_1}qKz{pvOV|qk{X2$ICM*dJth27yn=nHq1={c?TIf0Z$?EFeOXCJ!-*uvA^ zE-E3C?x#^L^RWw+FFZxC-h!G zEN=@ctjsYjfq+Gd%1<0)wM6{R8!=@RITH%w=fSsVBiX$>-jVu@pDFvs%krrv)H{Bv zqi@(14u%MU?I696$zG-WL=|$QLQi`-f7G~mL@2CjVwmD2nd6!i$eEr$5%_7L`vDQ?i9`OXQPLEx zC3N=LD?&#?hR#fvef;RxvcqgK=%-g&H$hqu(0Jz1Z49e< zI+dE#5`5|GvOwt*OjYb}ezEjotx;d|v7X!7zA&b$6f;Z_E9acTR&rX2{>vr5=cMb8 zF+LeO`%cT7t?*aB$CuLUL0i&~i&cOA&he^9M`)YKZTlwp1*e8s>F#65YMC#u&8n!j zRzRD3Z*oJLy8`(l{uop7M$<7oF!^J~|Jq`X@_phr8|Dg2iMfs~SISXe^0684)^6Av zucNI^8k`4BEb&e;ADSr!S()X`DTV%ELVFo@Uo5~~7s6Z@LR=Ttt9}VGC@v(fS+n>H zrg}eV7-YK&)1Hab?8cX0@2O%R$o5xju5vUB zUKx$NoBM77qrQ$tP|6|2N08EQTI2juqD@nZgjWf5k;3^hsiq{Kk>UN2(7$$CtsOii z^sNtGzEgQA93>tWJ#R@w55|6%aH{x(Gnf@MTdp$rIaBym@hiia{QgxwR{XUwe8U|* zJoy8?7mYM32AEqk&2Mx&&GNr#{8DqKR`5=&`Q^DRr5!O9cGZWcTZUVFDpw;@LDMZj z<-{~ro%n9E^2M~BO}icGxC@TokX%KBe0^SVfKP070x%TRpA^d12u1c#8`CPR;JIo2 zrActbP1?ij&{VbSw9jwR7?GFAX?t%4$wY%>z9_|wvc9CTo$!KvgOxqSrps+I z)@?E*Nk5@FIKSl;q{A$MMSH}a^%>bN#hG{6QiWvnDd>9^)8qh?X*(xjTjGKc7ij~* zPI)|7GIR&MsFm|As>&&9(QUJ?UFR!Xs755Sc-{ycmY_~V6Q&RN_Y^4V@+oYz)Ff|k@vOyAQBmpcyitBHJ{ zTcov$N3+I&npEKj7Mbx^zl<2RedYNXPjK)I=tp@{ie@4S=l+CwN`B^i#G~-lT8kU} zRl8H~-ATW#gja`kAaST;{!l!oA*iNb-i5?bQ%qMZUtEcOl}YD?<0nH$37jFhTpXfC z{4jTMMd5sL9rjJuDj3=UkvDzS3r_-^pXlwrOpg{lPnlcL6JL3;7)8E{=^*zi(6aK* z9iL}@u<&vDHmujV=y zYH=6Lm5S(@YDn|Fs73I$_C7kz^|=(@G)`|bR3k99+prNH!MAf56HojOddKN)`HS0V zy;`4pwlaQLL|sG8xo(y3Ojy7aa^tupsv(TdgoOPh%)R}QgT&rs6f0yHl_V+J7f5>k z;>L0MQn3BATH$C(meBpJ)(}J`Rtz2qS-v8*P!dOs#c|L)IqdNYx_g z5=r&ymB1=3>idO27p(Y&K#s!fC#MJUH<((o9$KP)=Uo+Ln zPYrZiK*epr+XwH|o*q0Ljh6I%p^IZ_ zGn1KF$}cs6NLO70p3#X`|5yW>(HT|}zwnh^@nd#eK3iBnzchHDBAQ8${A+Qr?%Rao z_P|Vd{J16-!MGh-ss`c_Bx$WZpFO?^!=-QOphH4mn#uidBJVVx`yNgM5pJCVTIx9X#xo&?=YT^D(Zzl%N*sRXhSQmE#t1u zZ{G#eVcpS$t79UchQew%J}w`JZ?wvn9Hzzc4%~`Pi2ES&G2vfg3gva5Yx5wVXp!J$ z$9@CL8ZGq3aPp*l>hr%SRJcRunt#+nkUZH=0&KEQ1nqnmON}tojQR#^9>Ve!0(s+C z5C&!ABah+-WMiw^LWvbIlS}MMo;tfnNn$yM%0Ar(e8-1hmkG)-Sf}J}#*^L;QTR7} z>SuHF1sf9XA2klBI59Z=XzSqL;A6^?@^(+Gz;jPh@#kC~`820E^7l#=Q(!F7BWV(6 z6q^PMsoYAP{8K8%LfvThJn7*zf1jZ!s}9PlK`4obxi74>i5-ik+x$7UAUl@FlFnaW z_4;MOVBTanKOEmUbh3LYS+bzInRWR2pjHTl%opY-&6YwUUYeyu?xC)s{hy}-w2d`( zwd(uY$^V9%`aV)9T1}-SCwAx!6F|)4%T2wb38@IPD4@YIw@=tV{J{{Aq-7S+w!8BF zBr6~sjj&_15?Vnu$r^((;4g3L(fbgaqU^8RwS-R%A+k8;>W?3cUi z9opw5d@nM33ACnnw(DFY|WzNYAag5i`G`N{i?N~8ivW4~JVX{T`bYpUaJsWre zo1@0rwSu=sITS?9xXY?n+WK%%zrP^1zE|XxEr}$3 zH$fbVq-maX@cZo5?Q3aQv5bi74Rx$k@8|Ghg^$6rOB{J}C)mZD?tSzyJq~ zoyp|^EE=kZcZ)}!uh^cQo|4xlR3uCB2<=f&qbx4AL^Sr`$j z+rER_=+G7{SSlH>aj$+@3*Ebx z2k=^ES4p@z`ZNJ7%~alt+o(FpEOMT&=7I)ljyb}B9wYYmse)t5?hx9q>wXiTi{`lW z7F2i(D8QioTjuVQr69pWX*348r7O@i2+9~Tk%IF|tFJgTAWWZ7wJZoP=Q-$&VS8~h zUHhPnCcOQ}7X@^>&eHd&yTZ`GXLU|ONu*ek)bBD$EG?LJMc}Py)6}puWBFEXm?O|) zzm|=Px;;QAe9+b}CL)5XbnhVi`*;v|hirWM4Hi5bbWs(eDSU|@+!h^7NpwMr0>NBT z0^Nf3$GCg5)%5vRN%D|M#y7Nbso*zC?ZqFu*xq%lkREvov10H*7Z0ey*x3KV@)a3A zepab-V&TG4$L}0PGnD8k29K|BvhVS zJ7U$k=Q+xhnp*=08T&~P6dydT(*7-3@j#pa6| zyRp>x;X`Lp3hC#$Z653iHb1FHk-#Az!LHiIgFOTdhsjgnjS2n@2qslw91_TV&f zUj4Q5;Sao@qUgU96e)2)8^N2#=rI(d(t_hQCn{t^$|o9lz@cN%5A{$2boK(~#5TnO ztDy2hMY&>z$$~IJhfiU~Wg&RFmCh^=t0h#N-wHtAfgB4%^Z@Svh==nqsTV^3q)lK! 
z!hWH+;&t?(>{3AEP-7-|ddt)(GP~lYkF;YLKx-h*RQF%5tiDwH znq+TdFQanbe+LQkpyd!kNjJrY*oBCLm8>lVQPhbHao|FrO3Rw3=*&3!LwH9l5==G5 z1UagpPo_t&6$w5t;?cc(GzlWt;P4j}iFgxr~T?ydpA>E(I zZL^Y;c+%si(Y0f6!PQuD-rR=OiA{%EjgrK}>*rWT^D#zz7}J5ECX4pGOerXg&1yOn zWtTaO4}Dq)Dvd|sqL;9yMnP|jZXBTKeDne&jgj*Kha%IK-2T+(Q3*#_slJWA3+Rj$ zJ~_%~9BWJ8k%iNlMPWJ2>YivQaiUq3^j2-z2Q0O*kx%&21Q7a0pZ-gQU>m{A8{YeNJdW zTo_sMO9oiiG*Acd1-@aKiF~Ub%vm()U}XDo(Zm7x1Qq&otjt(*RT&S-^_%OKR~OEo zHcKtH#mew5>8&U`VX|n`*;q|zur|<{xM8V*{6iiikz}68IQ1mM*ykZP2C6V5 zca-}t1(IZi2yN{o$H(%ov^P&aCXw~?l0%(98qf3Qs*|{}O1UxgpIE*{n*>>Ef+p3L z%O)$BYLdxHXH_!u6DIJACK^G-m^la>{k?p!{iN1sBbLEZLHkKGs3v$?qrEs(MA;Ho z5XOdL!VD7y&4CVSV6!#DigoJ^yW9MU7aW;@x+YSf)Wj z7zp0qZk>!Hq}rZ0ugK2g6+OPJ3(`a#dDEyQ`NZDRC%pd9qqC!pUN{3AZUkz6(*7-d z@aYnrpyxxB4Q5Xn5GBZy+0e~co`TD$CP9)oFm-%;9@K~_cg}5ClhjQWWKoZ@PszOm z9(2ZrGoVdB0qW6~tK(z16(KHc4a)=!QwE)J8oKJsGw4)0a_|s0sW~6L1WAK)CUAf@ z!+$5s_3*mJ@M1@xssMB$0?K{?+hW}zu9wIA9SPUO3ZQ^yfG_9_-OS}{ISt=Z^&LkB z$vdAafm+aV%HEU^kor=0T3{Q%#d1+e2`YqXchdmh>HEG&z)9i|uiWo*q;W)$r7|Q4_Z)81~7^kPoa+ zYxNfN5e#c=?p_+}yFlmY-bc|kHecsBO}Y%HJWak_!@>7C7w9umRC%(WdYk)$zN7yg zod0bepW}uvs8IY+LAP$ z1@-4S?s^zkeSRO_!;gCvIPnWER0y88U*phmx^N;KJp<$Tj0`Po3?$skx`Ka1>-hcl z3jMNT{79mahptX+l;v?XNxnm+V~u*ns#g3cEn(dbXc`kb#Y5NRHA)#Q9vOT?G|d={ z3i4qAVY1^6IGC#G)7ugm)AQCypvgS}ivXiD`FB{l9otG?>K}y)A(Sc_Y}gA>Q@V4q zAdRY6$I}Qa8?Byeg)RE>GI`~8 z6^go*K#inw-gjGEli+ohr~WyP5noT~KI-%h$X~8KFPT?9B z6r-hDz#C+eGSOAx__7Vi7&}MrO$p4h7$8_v*;zU-iDcxBMnYM1Iu5 z{^-y-P`z3Ew-3$HaC+;dXK4EhOYWdn4ER(sV8cJ9bhek)@l>;Dq7HN-23>dobLQ|D ztnyBN*=m`s1-=q2IR|-xEDw05-y{$como+NX+ZJSxVTOPJVtMlSR9XZ zin_mmBrtMT-o(Tv2E6-|9V6LA3T*X)RXU(+6llSg7&Yl7Nw7l~yn^UL6lECYkQNk_ z!ONR}`K+D0GvRgQt-13lu-##f!~<_(s~mX9#{rLS9_Yd=(k9TMHyE%*a#!PCrLY^~ zcQ;({hs?YD&>MW%KwpQ#CmuD5GS-LalvnvGyfoTAS_(LXY8tsF)rt8jBOt$5(T?-# zYHGm28AoOwAG!r9O;V_H;?N^D)Wj=&52wK@#fBS$N;4nu<`*RcRUj`k;%rtl;Rh-v z%)!9X&&+3jC$5Hl;P9OItjb1T6FiA$IkY7<^mG)(imwxzIf!1geQ;b&Y zz#PRA&m#AjYNm1iI~Nd+%ZE~974PcVC(bn-P(N^CU5Fk}13~9i7|)VlE(*e0UmR4*2uJWt zS0{q8ngsD^IpL;Qv}mw)5N#3=19dU*?)8rOjJ>Fv7le;Ff`uMKAVkus{53NB)5nnK znw^7BQ9bcMKe(sgB$*IBdWP3b9_)_x9y13M`uNVhDS8P7OW%`f17}f2Idl+rzDZ45 zBw;5<7>{M17Fa1wPY|vK(pRfxU08UUG6F93k#MQftpY)tH^nc`{=@k@H^MBVV79IM@TrxBn3d< zpwpKZ_dA^8;G=Ofu+gJ9eh*QSY#^*^l>Iz`s-UkfP=X0QBZ7bA2D;4-%l{Fc$wRn4 z(=IVo8VB}6J1ks;=L4@qA9wXd{rEB|$QVCI^i7ODi!aG@4;LN99rH=6hzm0NGhnX= zFZnfUmPEPAEj<Dq*uRwv@bhQHMWPPwJ^?YHEh*!O?X41_$Woxn$}d@=F@To}KvJIo8XkYMosfV3 zWnwsAUlwx&8;o$+}QuJSVNMGuYt`Tv}w?&MEm67QtpXKw4-3}fsbEPWcm#lE(`Wi1>U=^ z!{Ke+@Bysm5O}>V166iEDt-}QB}^BnhNLu_%1>T4a}L-52drJsMWJt&NakP z8q~t55Isz%=s!c+N|byji0M(Uti6iXNMl4qkB!BdgJZp70B!F%9oFqB>g(9Z`@ANps{aCyhsJD9wkRde0VJCV`wD6!vl8(RRMfZFUzF~Pw+~J2o%N+;t;j$p%CSoWmcvBnrG7Tkq zWQa!>0?);w!+>Xi(mrqlJymWy@$@N^U;g+uCuklo#}&QAONgfP2hbCxTHrcM{S}ZG z5BxWd7`?4jhm@7!d(eQAAp^V_H1M39Vw8*%7?BEHh@J+!iOk~Df<3|g1#~jf> zFR>TG?L5FTfV<|RtXt{>8xo=A@!fPp3FE`-q`-zr5_?g?_c5lW?^5Xh0cbVrt_6-7 zA6X~^XoTdh6q(-1vOc{2YIrdqP!$2Lr-Sm4xz6^nUgnq5ZB{C<53Tz0Op3A z*-8j1F@8N)D=SAIkpM}5wVY7R>N#}iAcvV3Lh;iko7`e_)vCz z!&*QvO7NL>F)YEpBt)qosyTMrAfKWR>?jE}|7!9^$QJY60U71(L*DE8v>6vr06o;7 z%=J?rEmnXIUI1>e2g@NnXdiv4=`Mmm;4dY%vU!Wi#?qjibMQ$pFc2?^ItkzLEV1OG z5LoFmqVAGH>49;|qD*w=MCGLmrGGO05*mmB8@?+ZjuhV6KGm))pc7E!3cx2{LzFNd zhyJmM!bQ69URR~fh=Z;vpo?6ATWsnNJ!QEh|3#Uv z4?!I{05z)AIE#JiWu&@ws7HB-O@2see<}`th`Nh2@olCzpxF+dY$G1P` z(<0Nj85sUnV;A8fq#--@+M^eCSVd0FMtvda*w;e9@252E>KXx3%bNi-!@ok{j}02o zw|u0Q#K3K%5O}gcBY{^Xoyasi>Ejt;g#yZ-D?F?@KN(+0{^;p#f2f9I$+*4EUmne^QP#3gs$55_LUlvZ{-|+lJM^Q4DiD;01 zcDxsA;jXVrP_tYfcda{9-V+qKYfQTIK?u!j;H<$(=aVV%?bla99DW4|VfFrRJwGO5 
zGgz8=PHuvkSuhD{=^4ej-2Gy_4*!JjHP%qgwPrS%ZBXnrUh(cpym1$NeeXPjw6}XY`*zs7I)>ffW`klDFNe4&%Da5Gj8rI*S?td_St@zan|8ejNU`h^ zpS}sl)0&SDOjk9eFW&Sj1%8ha(l%aa17{n&ZNiah-b4DWkj&~eGYiZ6nGO=aB4bQ( zu|KkSqO595w#eXsqo=D(QbQ^`r!y+S0uC_20t>N%TAT7~VpUKq2y?r3fKRE=1afD) zv=|E*4P|QYjd%LhPubo${^XSL7&eRjjchtN(!z*Q&tg(J9dTOpWao=Q94nds8aHw| z{G2-M0K3Cmrb#fh_A@Q~Mrw&$x=`mM{a8i3o(Z8uG~Vj{`5TCre~LPrhCUD;AqLn4IyTRc-z`Q-`}$M`RR}9t+}^bQ{*S9 zj5ho(W8Ryi-X4;5uNS8+y#GFCoj;+J1z}0&H?9m9U0&=MJ3Jc>H&_q7{U}!Q(88?E z>+_QZ$(N~G6q?v?3{A=V7rpHrR!UMKZd>HGT66guh?kmU)JrXW3&}Gyr_|C2Cx?Z7 zb7XRqWb4Ap*5%7&nGwoCm4zcFJ=4MBUcW;)d(n!sD;0m{}8r9#dbkQ_VpL zmUw~_mvQgSUll{AZ8hOCKDV>KU(oC<7S-HRUe!v}A!=JYb_j>0=Cm_pzBC|tt~={p zPKqwfly=1TOlw=2Q##j~CuL^eLwfShF6#7Z_8Yy_aUD??`HKD9o+4xQk+fnyO^d09 zv;p2*zt2mb23G0c9Ek6O4%FT2#|=!b9NwHls%TCY?B;Gg`3;c#UO`P_)aBh-z0W?E ze!neq-jO|clWLyX=I3DW!PsNZc`cB(KTm2uOONK%&G~er41t1JocpW)(TNvgcC$Fd znExcb!uXdL;K=MIyGo6Z^LW{SLXhS()9=K)b{D(`d7YeL;4&sod%F3>%3m^?9q2@0Fe1#tSP- zkSod?2=@ilw^}@&_64=aZAQnZi^%6I=O%4eCB9@U=OtlKJ!J1ii_-bU&7k~5fCEPM zf@sMaLi_pbY15oY8KkrKsC=G!Cjk3Bs!`KdcG5n2={)L13tmOuR<*ZlewV{_p2yk4 zIkn5Y+yMnYaCLOJ@-S`}ecD)~J=#jb-T~t`+k-YHTq5=TdBvpkf~+Xw=r-ul0>m?n zC74m4D%sFCeo|lN)1aANO>~}ehL?q?%jamwhd3nV!o2Uv#1Ev?O+|Np4RQ_25$|^n zf86WitWN57(-wYNf%KSZ^>|Vx9DNH3CW0+7ZXq=ruaGYw(z0_aoSST-W0zCq*uPGW zf9b~@CMp>td_`|ZF9VP}b9VGRi(-^manE!Xf7V4LlvaDR>3A>72q9Wq$j|nI(bC&Nut@e0We+T&!xIa4>mNnV? zo(-v$uog$UzU}e4ow1MIyRI4{ySj9sC7ru;$hd^q0EE_ag!+E&CA4fBYW>``lsw1{r5~rtt245D{$w` zD*EQSYNhQm0=b@z%z)f_++HFxE*&B+kzax19a4fF$PE&5Jd7NYMT#S@ksa4noo8)g z$X%GrL7VsadW}V!;0AHrIbq}e;xk0h!lzG*OTa%CoC zzPH1@k~>~2r)G=zzM3Visu8z?6!_HUi-A735@m>3f*VNQ^^n8~<;GNS5Hj)O+;a7{ z$(v0ecvQp8dq2JyBU+NRa>ThZp$22#`goh)m5O|e8 z@ACy(UgQgLx>M%~@$T|#O+Q52DF;LFZGqWwS4Q>r&K8nk$i-vr)FrE}=8Jxyy6Y8V za@MWKSa_@FP1Q+V!|QvQr*7$}Q?dsX9s!0Wt*tKYS=JbD1vYe@^`! zACc?sUEXlDy(7Ne_4ey($7y)x_^P&z5^b_gr-)5~AAg37ZzUQcm(YcOzT!nTBq4(P z29k44EjME^QXOpPI>`>)a5Ag2A@jmD=o0Z5n0 zF{B%DIYdmH_99I`Q2J((uj6$1Ta(M}H$B;&^EP_`w*!I9yqlt_CthBB8B&?%%A#uW zYVrNSnP}Ue3RD}Stn8)S{ZA)0KKI5ZUNS6FJeox9M>d8>WdeJnV+(R<;X42FDmXdk zdSkd)zf@25-FX`Y;A+la+#}HC!2SxVvo?2?%^l(A=-G`nAS- z-d1%0j>)`YTwkJapmdq2yS|!2-l?B~Md>Ju%feoH>`7+?JfXeHa{yknE)D zs-$;g94+>Ux4N2oeIpK2{jv8=Z}`cj!}+d|X$hp@DUw4QVbo8vh_Ytk0T%~0ACtK6 zJdl#Cowlp+ooxTCj8UDtG!1i`BCD&;egmihI`jcLOp<8c^r!tbj|eQ|cshoANdpN|9vPO;yh7zb9K|EH+C%>52)aIG-jifFD@-=wqY&S=geL#9cSa_kf{bCOBj&PwsZ}IXV zwM}zsqwVt+zTJX@qFLKrzATK6sReFfrxhmKP80D*Hx-)CkyH6=wrT!!ja%?I&U*?8$Frd83VZHIkbvysGA|JKoiPA_)B1F8)mVi{2vf%HOHi zNPo-GYXje#5A+$jYAOFxDG7l79cacmAQ8p1B+I@LhsE87p3N2^JmmC{sPQg zSwsL0wN4;hdyF95Yw`Me&rHj5r@*?e*_8!9FnhY|uK^|?}fV`P} zVghqW=+4sPGK63j^6mB((zqr%;&!8D_2$6wzF&eioHikidQwKpB|)(VkW65-cr*p! ztyzd%e~7pr8$tvRBX$bb?=ztD99iIoG{}4;%tlzwxLYhn=yuh%i6Fp7AQdkG zBCGe1d6y0sflS3WE^|nwDe><7YihGA*>PZf2Y7ye{&9Aj`l|-{Wd4@F5h%mw_vgEr z6uEWF|8(k>@aC$GO2_p8LY>U@4Y&mKPY#gQ8tYW4ZEWknLbB-Hvd-?LtZ?HQ3uvCRCerTB9`eHB3?g{VkU+JMw7!sL`YwUoo<-(Dkd*(i`uJ{? 
zC~E+Q0W6z~$ZkVxVD<=hzxl5S8z4fVJ_v|VxZDYTq*_24?+V_@(h1JrX#fL8^mcd` zaM8jf`1>B&?cNRqz;c@oAfF%^yF1c!o_WMv-tzZ2#{*3DY)4vMO8!rwlT=Fyv#$v=O6ri|w;t zZ8ff6-1SQCcCvDPR?>~(LKgeln4)^)uOojzF@3fU=disYE-57Y^xm zq%jY`Cc})`^!x#GCM>FlU!iCT=@x3y(%?F+i^R|7$uk}bN;A404MnE?N0@3+1@N79 zwSWGq5UDt{_p*^0`R8pPF=IL%-!M*+rDl7}Ay8nP`W`;ivv3RI%1`rKx5 ziGOS!x;>^u$VoU|WRI=>`PO^*i*eXrusT(prq#7k>T?vH$L~{8S?T6JJm31qLf1uxl4%ym)FJ@9F9WUz{F>j-~cYuRqk zD-0)#$(e6}zLB7W!58U5$3?k4gM3CHdDW#kHy z;SPC}ET%aPa_=UEhtcSY6!P9*5qS!_lkdtdrAtJT{3~NbgPasnYi_6<(02+r>rC-c zx0@;>zT_|HGT~|R{Lmw-5L2h;wZn-UGzS^SUMT+JfA^erJGp*AF-$Z7aGAmG+NH1 zZ+C(HT22Q}98${n&}J=h*%n9uH6i1BXh9}-+2_c@P3&O=Dtk8jg4%iXf|y?i$nEoM zJQ1hsm&VP)h-(kuaG?OsHYB-ps0Ixr@-D~A4v|Wm9zd=B+JJoiXTet~1?hF4xa@B7 zKiL!>aWY8a3hjGayro!ycU2daZAYHl7D{puDT91c4U`S3th$;6rHD0$+qKaosGBe} z35DPV_**}nqAe^2BSBRDXeI4|P!7uKIaA&Smy#_k0|Tb$@rQ_<$t8TbfLh7+kOrU0 z<=yfRVpXtmA;|xPqSB%OY>WTmS}rFN(6i-xajlpgiFY<)FBJRBMjNp6I5p0Tdxc5pCB6p&U z21GCZRT2l3%#Ly(RZ~T{y`5a>E&pF^#o!-2?&v!E1KYy}38wCNBE@)9F{zE){ zA|RAkrl14FrFd5!KF_TcF)T*zrF5wPshstNOObm|Zh`;tluDV{o(g#K|0gKPl?OtP zp4?O6V)73a$B`9|JJfc3FPa;8AE5pddglY7i%%F@uf4wARc5m+&x6pa`{allD%r}R z7e11Vzj${~4MjmD-hCY$yol~yGNF+hYWXL_W(vgA0m>%SJV4F3$=tzl{`qsm@vdY% zs-t@UDZyvyk$9~I;PKCyMIaZuV;5rTYdO?u?hMOhl=nJ!EYy820MmWY{H_ELxm4=< zp6@{@qmxUi-sjAC9iaW}XC65J;h=#Qxqjly18|5sY-Bru)a^*!?i^0`+s^`lg#Q!B zF8BIu&{r@WaJm=7t)4(UC4cRv3svnlo%8r78HUr;^+ zD2soLWmS~sDd(YR0C&zKpxC>pcF$XxywqO$mwT#2&wt`Ofu7#QahJ8p zNPzYO4Rk}mwCI0m7d`%|>G)sP%bcY4w&J;gIUxYh=JFI(;2%pYb3;y`nETumTbGxi z;{%7t7I=$;)Q3uVhr!+&({GLQx=)qDfr4Q Date: Mon, 16 Feb 2026 21:24:02 +0100 Subject: [PATCH 5/5] chore(lint): for pet and torch.compile --- docs/src/dev-docs/changelog.rst | 8 +++ src/metatrain/pet/model.py | 3 +- src/metatrain/pet/modules/compile.py | 76 ++++++++++++++++--------- src/metatrain/pet/modules/utilities.py | 2 + src/metatrain/pet/tests/test_compile.py | 4 +- src/metatrain/pet/trainer.py | 56 +++++++++--------- 6 files changed, 87 insertions(+), 62 deletions(-) diff --git a/docs/src/dev-docs/changelog.rst b/docs/src/dev-docs/changelog.rst index e3656642b0..7aa1bb6ae7 100644 --- a/docs/src/dev-docs/changelog.rst +++ b/docs/src/dev-docs/changelog.rst @@ -24,6 +24,14 @@ changelog `_ format. This project follows Unreleased ---------- +Added +##### + +- The PET architecture now supports full-graph FX compilation for training via the + ``compile`` hyperparameter. When enabled, the entire model (including force/stress + computation) is traced into a single FX graph and compiled with ``torch.compile``, + providing maximum kernel fusion and consistently using scaled dot-product attention. 
+ Version 2026.1 - 2026-01-07 --------------------------- diff --git a/src/metatrain/pet/model.py b/src/metatrain/pet/model.py index c5cb7e5b5a..be8a146cfd 100644 --- a/src/metatrain/pet/model.py +++ b/src/metatrain/pet/model.py @@ -1415,8 +1415,7 @@ def upgrade_checkpoint(cls, checkpoint: Dict) -> Dict: def get_checkpoint(self) -> Dict: # Get state dict, handling compiled modules by removing _orig_mod prefix state_dict = { - k.replace("._orig_mod", ""): v - for k, v in self.state_dict().items() + k.replace("._orig_mod", ""): v for k, v in self.state_dict().items() } model_state_dict = dict(state_dict) model_state_dict["finetune_config"] = self.finetune_config diff --git a/src/metatrain/pet/modules/compile.py b/src/metatrain/pet/modules/compile.py index 56287e1c44..f27122cf93 100644 --- a/src/metatrain/pet/modules/compile.py +++ b/src/metatrain/pet/modules/compile.py @@ -9,7 +9,7 @@ """ import logging -from typing import Dict, List, Optional, Tuple +from typing import Any, Callable, Dict, List, Optional, Tuple import torch from torch.nn.utils._named_member_accessor import NamedMemberAccessor @@ -22,6 +22,8 @@ class _PETBatchForward(torch.nn.Module): PET is registered as a submodule so its parameters/buffers are visible to ``functional_call`` / ``NamedMemberAccessor``. + + :param pet: The PET model whose ``_forward_from_batch`` is called. """ def __init__(self, pet: torch.nn.Module) -> None: @@ -57,13 +59,31 @@ def _make_pet_compiled_forward( output_shapes: Dict[str, Dict[str, List[int]]], compute_forces: bool, compute_stress: bool, -): +) -> Callable[ + ..., + Tuple[ + Optional[torch.Tensor], + Optional[torch.Tensor], + Optional[torch.Tensor], + Dict[str, Dict[str, torch.Tensor]], + ], +]: """Build the traceable forward function for ``make_fx``. The returned function accepts all batch tensors and the model's parameters/buffers as positional arguments (required by ``make_fx`` with ``functional_call``). It returns ``(per_structure_preds, forces, stress, raw_predictions)``. + + :param batch_model: Wrapper module whose ``forward`` delegates to + ``pet._forward_from_batch``. + :param param_names: Ordered parameter names for the batch model. + :param buffer_names: Ordered buffer names for the batch model. + :param target_names: Names of the prediction targets. + :param output_shapes: Mapping of target name to block key to shape. + :param compute_forces: Whether to include force computation in the graph. + :param compute_stress: Whether to include stress computation in the graph. + :return: A callable that can be traced by ``make_fx``. 
""" n_params = len(param_names) accessor = NamedMemberAccessor(batch_model) @@ -87,17 +107,22 @@ def _make_pet_compiled_forward( ) def forward_fn( - edge_vectors, - element_indices_nodes, - element_indices_neighbors, - padding_mask, - reverse_neighbor_index, - cutoff_factors, - system_indices, - neighbor_atom_indices, - n_structures, - *params_and_buffers, - ): + edge_vectors: torch.Tensor, + element_indices_nodes: torch.Tensor, + element_indices_neighbors: torch.Tensor, + padding_mask: torch.Tensor, + reverse_neighbor_index: torch.Tensor, + cutoff_factors: torch.Tensor, + system_indices: torch.Tensor, + neighbor_atom_indices: torch.Tensor, + n_structures: int, + *params_and_buffers: torch.Tensor, + ) -> Tuple[ + Optional[torch.Tensor], + Optional[torch.Tensor], + Optional[torch.Tensor], + Dict[str, Dict[str, torch.Tensor]], + ]: # Swap in the provided params/buffers via NamedMemberAccessor params_buffers = {} for i, name in enumerate(param_names): @@ -105,9 +130,7 @@ def forward_fn( for i, name in enumerate(buffer_names): params_buffers[name] = params_and_buffers[n_params + i] - orig_values, _ = accessor.swap_tensors_dict( - params_buffers, allow_missing=True - ) + orig_values, _ = accessor.swap_tensors_dict(params_buffers, allow_missing=True) # Compute edge_distances inside compiled graph (differentiable) edge_distances = torch.sqrt((edge_vectors**2).sum(-1) + 1e-15) @@ -139,9 +162,7 @@ def forward_fn( energy = torch.zeros( n_struct, dtype=edge_vectors.dtype, device=edge_vectors.device ) - energy.scatter_add_( - 0, system_indices, per_atom_energy.squeeze(-1) - ) + energy.scatter_add_(0, system_indices, per_atom_energy.squeeze(-1)) if (compute_forces or compute_stress) and energy is not None: (dE_dR,) = torch.autograd.grad( @@ -168,8 +189,11 @@ def forward_fn( # Virial: sigma = (1/V) sum r otimes (dE/dr) virial_per_atom = torch.einsum("ema,emb->eab", edge_vectors, dE_dR) stress_buf = torch.zeros( - n_struct, 3, 3, - dtype=edge_vectors.dtype, device=edge_vectors.device, + n_struct, + 3, + 3, + dtype=edge_vectors.dtype, + device=edge_vectors.device, ) stress_buf.scatter_add_( 0, @@ -188,7 +212,7 @@ def forward_fn( def compile_pet_model( model: torch.nn.Module, - train_dataloader, + train_dataloader: Any, compute_forces: bool, compute_stress: bool, ) -> Tuple[torch.nn.Module, List[str], List[str]]: @@ -258,9 +282,7 @@ def compile_pet_model( # edge_vectors needs grad for force tracing tracing_edge_vectors = edge_vectors.clone().requires_grad_(True) - logging.info( - "Tracing PET model with make_fx (symbolic tracing)..." - ) + logging.info("Tracing PET model with make_fx (symbolic tracing)...") old_duck = torch.fx.experimental._config.use_duck_shape torch.fx.experimental._config.use_duck_shape = False @@ -286,8 +308,6 @@ def compile_pet_model( torch.fx.experimental._config.use_duck_shape = old_duck logging.info("Compiling traced FX graph with torch.compile...") - compiled = torch.compile( - fx_graph, dynamic=True, fullgraph=True - ) + compiled = torch.compile(fx_graph, dynamic=True, fullgraph=True) return compiled, param_names, buffer_names diff --git a/src/metatrain/pet/modules/utilities.py b/src/metatrain/pet/modules/utilities.py index 2414f2f1ca..2a3e88f4a3 100644 --- a/src/metatrain/pet/modules/utilities.py +++ b/src/metatrain/pet/modules/utilities.py @@ -69,6 +69,8 @@ def replace_silu_modules(module: torch.nn.Module) -> None: """Replace all ``torch.nn.SiLU`` instances with :class:`DecomposedSiLU`. Recurses through the module tree, including inside ``nn.Sequential``. 
+ + :param module: The module to recursively modify in-place. """ for name, child in module.named_children(): if isinstance(child, torch.nn.SiLU): diff --git a/src/metatrain/pet/tests/test_compile.py b/src/metatrain/pet/tests/test_compile.py index bf88f03403..7ca028b824 100644 --- a/src/metatrain/pet/tests/test_compile.py +++ b/src/metatrain/pet/tests/test_compile.py @@ -70,7 +70,7 @@ def _make_inputs(n_atoms=5, max_neighbors=10, d_model=8, dim_node_features=16): def test_compile_cartesian_transformer(): - """Test that CartesianTransformer compiles with fullgraph=True and SDPA attention.""" + """Test CartesianTransformer with fullgraph=True and SDPA attention.""" ct = _make_cartesian_transformer() compiled_ct = torch.compile(ct, fullgraph=True) @@ -145,8 +145,8 @@ def test_forward_from_batch(): from metatrain.utils.data.target_info import get_energy_target_info from metatrain.utils.neighbor_lists import get_system_with_neighbor_lists - from . import DATASET_PATH, MODEL_HYPERS from ..modules.structures import systems_to_batch + from . import DATASET_PATH, MODEL_HYPERS torch.manual_seed(42) diff --git a/src/metatrain/pet/trainer.py b/src/metatrain/pet/trainer.py index e7264601ca..c114be5dcb 100644 --- a/src/metatrain/pet/trainer.py +++ b/src/metatrain/pet/trainer.py @@ -5,6 +5,8 @@ from typing import Any, Dict, List, Literal, Optional, Union, cast import torch +from metatensor.torch import Labels, TensorBlock, TensorMap +from metatomic.torch import System from torch.optim.lr_scheduler import LambdaLR from torch.utils.data import DataLoader, DistributedSampler @@ -37,9 +39,6 @@ from metatrain.utils.scaler import get_remove_scale_transform from metatrain.utils.transfer import batch_to -from metatensor.torch import Labels, TensorBlock, TensorMap -from metatomic.torch import System - from . import checkpoints from .documentation import TrainerHypers from .model import PET @@ -62,6 +61,17 @@ def _wrap_compiled_output( Produces the same format as ``evaluate_model`` so the loss function and metric accumulators work unchanged. + + :param energy: Per-structure energy tensor from the compiled function. + :param forces: Per-atom force tensor, or ``None``. + :param stress: Per-structure stress tensor, or ``None``. + :param raw_predictions: Per-atom predictions keyed by target and block. + :param model: The PET model instance. + :param systems: The input systems for this batch. + :param sample_labels: Labels indicating system and atom indices. + :param system_indices: System index for each atom in the batch. + :param train_targets: Target information dict from the training config. + :return: Predictions as ``Dict[str, TensorMap]``. 
""" from metatrain.utils.sum_over_atoms import sum_over_atoms @@ -82,15 +92,13 @@ def _wrap_compiled_output( values=energy.unsqueeze(-1), samples=Labels( "system", - torch.arange( - n_structures, device=device, dtype=torch.int32 - ).unsqueeze(-1), + torch.arange(n_structures, device=device, dtype=torch.int32).unsqueeze( + -1 + ), assume_unique=True, ), components=[], - properties=Labels( - "energy", torch.tensor([[0]], device=device) - ), + properties=Labels("energy", torch.tensor([[0]], device=device)), ) if forces is not None: @@ -101,40 +109,30 @@ def _wrap_compiled_output( values=sample_labels.values.to(torch.int32), assume_unique=True, ).to(device) - xyz_labels = Labels( - "xyz", torch.tensor([[0], [1], [2]], device=device) - ) + xyz_labels = Labels("xyz", torch.tensor([[0], [1], [2]], device=device)) forces_block = TensorBlock( values=forces.unsqueeze(-1), samples=grad_samples, components=[xyz_labels], - properties=Labels( - "energy", torch.tensor([[0]], device=device) - ), + properties=Labels("energy", torch.tensor([[0]], device=device)), ) energy_block.add_gradient("positions", forces_block) if stress is not None: stress_samples = Labels( "sample", - torch.arange( - n_structures, device=device, dtype=torch.int32 - ).unsqueeze(-1), + torch.arange(n_structures, device=device, dtype=torch.int32).unsqueeze( + -1 + ), assume_unique=True, ) - xyz1 = Labels( - "xyz_1", torch.tensor([[0], [1], [2]], device=device) - ) - xyz2 = Labels( - "xyz_2", torch.tensor([[0], [1], [2]], device=device) - ) + xyz1 = Labels("xyz_1", torch.tensor([[0], [1], [2]], device=device)) + xyz2 = Labels("xyz_2", torch.tensor([[0], [1], [2]], device=device)) stress_block = TensorBlock( values=stress.unsqueeze(-1), samples=stress_samples, components=[xyz1, xyz2], - properties=Labels( - "energy", torch.tensor([[0]], device=device) - ), + properties=Labels("energy", torch.tensor([[0]], device=device)), ) energy_block.add_gradient("strain", stress_block) @@ -525,9 +523,7 @@ def train( has_strain_gradients, ) ) - logging.info( - "FX compilation complete (will optimize on first call)" - ) + logging.info("FX compilation complete (will optimize on first call)") start_epoch = 0 if self.epoch is None else self.epoch + 1