diff --git a/docs/src/dev-docs/changelog.rst b/docs/src/dev-docs/changelog.rst
index e3656642b0..7aa1bb6ae7 100644
--- a/docs/src/dev-docs/changelog.rst
+++ b/docs/src/dev-docs/changelog.rst
@@ -24,6 +24,14 @@ changelog `_ format. This project follows

 Unreleased
 ----------

+Added
+#####
+
+- The PET architecture now supports full-graph FX compilation for training via the
+  ``compile`` hyperparameter. When enabled, the entire model (including force/stress
+  computation) is traced into a single FX graph and compiled with ``torch.compile``,
+  providing maximum kernel fusion and consistently using scaled dot-product attention.
+
 Version 2026.1 - 2026-01-07
 ---------------------------
diff --git a/pyproject.toml b/pyproject.toml
index 4025b408c3..18ea6df8cf 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -163,6 +163,7 @@ filterwarnings = [
     "ignore:`torch.jit.script` is deprecated. Please switch to `torch.compile` or `torch.export`:DeprecationWarning",
     "ignore:`torch.jit.save` is deprecated. Please switch to `torch.export`:DeprecationWarning",
     "ignore:`torch.jit.load` is deprecated. Please switch to `torch.export`:DeprecationWarning",
+    "ignore:`torch.jit.script_method` is deprecated. Please switch to `torch.compile` or `torch.export`:DeprecationWarning",
     "ignore:`torch.jit.trace_method` is deprecated. Please switch to `torch.compile` or `torch.export`:DeprecationWarning",
     "ignore:`torch.jit.trace` is deprecated. Please switch to `torch.compile` or `torch.export`:DeprecationWarning",
     # PyTorch does not want these, but mypy requires them
@@ -177,7 +178,11 @@ filterwarnings = [
     # Multi-threaded tests clash with multi-process data-loading
     "ignore:This process \\(pid=\\d+\\) is multi-threaded, use of fork\\(\\) may lead to deadlocks in the child.:DeprecationWarning",
     # MACE warning with newer versions of pytorch (because they use e3nn==0.4.4)
-    "ignore:Environment variable TORCH_FORCE_NO_WEIGHTS_ONLY_LOAD detected, since the`weights_only` argument was not explicitly passed:UserWarning"
+    "ignore:Environment variable TORCH_FORCE_NO_WEIGHTS_ONLY_LOAD detected, since the`weights_only` argument was not explicitly passed:UserWarning",
+    # compiled_autograd + non-leaf tensors: dynamo incorrectly converts this C++
+    # warning to an error when tracing backward graphs through non-leaf tensors
+    # (e.g. edge_distances computed from edge_vectors via sqrt)
+    "ignore:The .grad attribute of a Tensor that is not a leaf Tensor is being accessed:UserWarning"
 ]
 addopts = ["-p", "mtt_plugin"]
 pythonpath = "src/metatrain/utils/testing"
diff --git a/src/metatrain/pet/checkpoints.py b/src/metatrain/pet/checkpoints.py
index 020b1e6590..f09c9afbb3 100644
--- a/src/metatrain/pet/checkpoints.py
+++ b/src/metatrain/pet/checkpoints.py
@@ -420,3 +420,12 @@ def trainer_update_v11_v12(checkpoint: dict) -> None:
     :param checkpoint: The checkpoint to update.
     """
     checkpoint["train_hypers"]["batch_atom_bounds"] = [None, None]
+
+
+def trainer_update_v12_v13(checkpoint: dict) -> None:
+    """
+    Update trainer checkpoint from version 12 to version 13.
+
+    :param checkpoint: The checkpoint to update.
+    """
+    checkpoint["train_hypers"]["compile"] = False
diff --git a/src/metatrain/pet/documentation.py b/src/metatrain/pet/documentation.py
index 8eaec9c03e..629e02c0ce 100644
--- a/src/metatrain/pet/documentation.py
+++ b/src/metatrain/pet/documentation.py
@@ -257,3 +257,14 @@ class TrainerHypers(TypedDict):

     See :ref:`label_fine_tuning_concept` for more details.
""" + compile: bool = False + """Whether to use full-graph FX compilation during training. + + When enabled, the entire PET model (including force/stress computation via + ``autograd.grad``) is traced into a single FX graph using ``make_fx`` and + then compiled with ``torch.compile(dynamic=True, fullgraph=True)``. This + gives maximum kernel fusion, zero compiled/eager boundary crossings, and + always uses ``scaled_dot_product_attention`` (SDPA). Expect a one-time + compilation cost at the start of training, followed by speedups on every + subsequent step. + """ diff --git a/src/metatrain/pet/model.py b/src/metatrain/pet/model.py index 757e778140..be8a146cfd 100644 --- a/src/metatrain/pet/model.py +++ b/src/metatrain/pet/model.py @@ -416,6 +416,7 @@ def forward( reverse_neighbor_index, cutoff_factors, system_indices, + _neighbor_atom_indices, sample_labels, ) = systems_to_batch( systems, @@ -570,6 +571,78 @@ def forward( return return_dict + def _forward_from_batch( + self, + element_indices_nodes: torch.Tensor, + element_indices_neighbors: torch.Tensor, + edge_vectors: torch.Tensor, + edge_distances: torch.Tensor, + padding_mask: torch.Tensor, + reverse_neighbor_index: torch.Tensor, + cutoff_factors: torch.Tensor, + ) -> Dict[str, Dict[str, torch.Tensor]]: + """ + Pure-tensor forward pass for FX compilation. + + Takes batch tensors and returns raw per-atom predictions as nested + dictionaries (target_name -> block_key -> tensor). Always uses SDPA + attention (no manual attention needed since forces are computed via + ``autograd.grad(create_graph=False)`` in the compiled graph, avoiding + double backward). + + :param element_indices_nodes: Atomic species of central atoms [n_atoms]. + :param element_indices_neighbors: Atomic species of neighbors + [n_atoms, max_neighbors]. + :param edge_vectors: Edge vectors [n_atoms, max_neighbors, 3]. + :param edge_distances: Edge distances [n_atoms, max_neighbors]. + :param padding_mask: Boolean mask for real neighbors + [n_atoms, max_neighbors]. + :param reverse_neighbor_index: Reversed neighbor index for message + passing [n_atoms, max_neighbors]. + :param cutoff_factors: Cutoff function values [n_atoms, max_neighbors]. + :return: Nested dict mapping target_name -> block_key -> per-atom + prediction tensor. 
+ """ + featurizer_inputs: Dict[str, torch.Tensor] = dict( + element_indices_nodes=element_indices_nodes, + element_indices_neighbors=element_indices_neighbors, + edge_vectors=edge_vectors, + edge_distances=edge_distances, + reverse_neighbor_index=reverse_neighbor_index, + padding_mask=padding_mask, + cutoff_factors=cutoff_factors, + ) + + # Always use SDPA (no double backward in compiled path) + node_features_list, edge_features_list = self._calculate_features( + featurizer_inputs, use_manual_attention=False + ) + + node_ll_dict, edge_ll_dict = self._calculate_last_layer_features( + node_features_list, edge_features_list + ) + + outputs_all: Dict[str, ModelOutput] = { + name: ModelOutput(per_atom=True) for name in self.target_names + } + node_preds, edge_preds = self._calculate_atomic_predictions( + node_ll_dict, edge_ll_dict, padding_mask, cutoff_factors, outputs_all + ) + + # Sum across GNN layers for each target/block + results: Dict[str, Dict[str, torch.Tensor]] = {} + for target_name in self.target_names: + block_results: Dict[str, torch.Tensor] = {} + node_layers = node_preds[target_name] + edge_layers = edge_preds[target_name] + for j, key in enumerate(self.output_shapes[target_name]): + total = node_layers[0][j] + edge_layers[0][j] + for i in range(1, len(node_layers)): + total = total + node_layers[i][j] + edge_layers[i][j] + block_results[key] = total + results[target_name] = block_results + return results + def _calculate_features( self, inputs: Dict[str, torch.Tensor], use_manual_attention: bool ) -> Tuple[List[torch.Tensor], List[torch.Tensor]]: @@ -1340,7 +1413,11 @@ def upgrade_checkpoint(cls, checkpoint: Dict) -> Dict: return checkpoint def get_checkpoint(self) -> Dict: - model_state_dict = self.state_dict() + # Get state dict, handling compiled modules by removing _orig_mod prefix + state_dict = { + k.replace("._orig_mod", ""): v for k, v in self.state_dict().items() + } + model_state_dict = dict(state_dict) model_state_dict["finetune_config"] = self.finetune_config checkpoint = { "architecture_name": "pet", @@ -1353,7 +1430,7 @@ def get_checkpoint(self) -> Dict: "epoch": None, "best_epoch": None, "model_state_dict": model_state_dict, - "best_model_state_dict": self.state_dict(), + "best_model_state_dict": state_dict, } return checkpoint diff --git a/src/metatrain/pet/modules/compile.py b/src/metatrain/pet/modules/compile.py new file mode 100644 index 0000000000..f27122cf93 --- /dev/null +++ b/src/metatrain/pet/modules/compile.py @@ -0,0 +1,313 @@ +"""Full-graph FX compilation for PET. + +Traces the entire PET forward pass (including force/stress computation via +``autograd.grad``) into a single FX graph, then compiles it with +``torch.compile(dynamic=True, fullgraph=True)``. This gives maximum kernel +fusion, zero compiled/eager boundary crossings, and always uses SDPA +(``scaled_dot_product_attention``) since forces use +``create_graph=False`` (no double backward). +""" + +import logging +from typing import Any, Callable, Dict, List, Optional, Tuple + +import torch +from torch.nn.utils._named_member_accessor import NamedMemberAccessor + +from .utilities import replace_silu_modules + + +class _PETBatchForward(torch.nn.Module): + """Thin wrapper whose ``forward()`` delegates to ``pet._forward_from_batch``. + + PET is registered as a submodule so its parameters/buffers are visible + to ``functional_call`` / ``NamedMemberAccessor``. + + :param pet: The PET model whose ``_forward_from_batch`` is called. 
+ """ + + def __init__(self, pet: torch.nn.Module) -> None: + super().__init__() + self.pet = pet + + def forward( + self, + element_indices_nodes: torch.Tensor, + element_indices_neighbors: torch.Tensor, + edge_vectors: torch.Tensor, + edge_distances: torch.Tensor, + padding_mask: torch.Tensor, + reverse_neighbor_index: torch.Tensor, + cutoff_factors: torch.Tensor, + ) -> Dict[str, Dict[str, torch.Tensor]]: + return self.pet._forward_from_batch( + element_indices_nodes, + element_indices_neighbors, + edge_vectors, + edge_distances, + padding_mask, + reverse_neighbor_index, + cutoff_factors, + ) + + +def _make_pet_compiled_forward( + batch_model: _PETBatchForward, + param_names: List[str], + buffer_names: List[str], + target_names: List[str], + output_shapes: Dict[str, Dict[str, List[int]]], + compute_forces: bool, + compute_stress: bool, +) -> Callable[ + ..., + Tuple[ + Optional[torch.Tensor], + Optional[torch.Tensor], + Optional[torch.Tensor], + Dict[str, Dict[str, torch.Tensor]], + ], +]: + """Build the traceable forward function for ``make_fx``. + + The returned function accepts all batch tensors and the model's + parameters/buffers as positional arguments (required by + ``make_fx`` with ``functional_call``). It returns + ``(per_structure_preds, forces, stress, raw_predictions)``. + + :param batch_model: Wrapper module whose ``forward`` delegates to + ``pet._forward_from_batch``. + :param param_names: Ordered parameter names for the batch model. + :param buffer_names: Ordered buffer names for the batch model. + :param target_names: Names of the prediction targets. + :param output_shapes: Mapping of target name to block key to shape. + :param compute_forces: Whether to include force computation in the graph. + :param compute_stress: Whether to include stress computation in the graph. + :return: A callable that can be traced by ``make_fx``. + """ + n_params = len(param_names) + accessor = NamedMemberAccessor(batch_model) + + # Identify which target is the energy target (quantity == "energy") + # For force/stress we need to aggregate per-atom energy to per-structure. + energy_target_name: Optional[str] = None + energy_block_key: Optional[str] = None + pet = batch_model.pet + for tname in target_names: + if hasattr(pet, "outputs") and tname in pet.outputs: + if pet.outputs[tname].quantity == "energy": + energy_target_name = tname + # First block key for this target + energy_block_key = next(iter(output_shapes[tname])) + break + + if (compute_forces or compute_stress) and energy_target_name is None: + raise ValueError( + "Force/stress compilation requested but no energy target found." 
+        )
+
+    def forward_fn(
+        edge_vectors: torch.Tensor,
+        element_indices_nodes: torch.Tensor,
+        element_indices_neighbors: torch.Tensor,
+        padding_mask: torch.Tensor,
+        reverse_neighbor_index: torch.Tensor,
+        cutoff_factors: torch.Tensor,
+        system_indices: torch.Tensor,
+        neighbor_atom_indices: torch.Tensor,
+        n_structures: int,
+        *params_and_buffers: torch.Tensor,
+    ) -> Tuple[
+        Optional[torch.Tensor],
+        Optional[torch.Tensor],
+        Optional[torch.Tensor],
+        Dict[str, Dict[str, torch.Tensor]],
+    ]:
+        # Swap in the provided params/buffers via NamedMemberAccessor
+        params_buffers = {}
+        for i, name in enumerate(param_names):
+            params_buffers[name] = params_and_buffers[i]
+        for i, name in enumerate(buffer_names):
+            params_buffers[name] = params_and_buffers[n_params + i]
+
+        orig_values, _ = accessor.swap_tensors_dict(params_buffers, allow_missing=True)
+
+        # Compute edge_distances inside the compiled graph (differentiable)
+        edge_distances = torch.sqrt((edge_vectors**2).sum(-1) + 1e-15)
+
+        raw_predictions = batch_model(
+            element_indices_nodes,
+            element_indices_neighbors,
+            edge_vectors,
+            edge_distances,
+            padding_mask,
+            reverse_neighbor_index,
+            cutoff_factors,
+        )
+
+        # Restore original params/buffers
+        accessor.swap_tensors_dict(orig_values, allow_missing=True)
+
+        # Aggregate per-atom predictions to per-structure for the energy target
+        n_atoms = edge_vectors.shape[0]
+        # +1 for the padding structure index (scatter needs valid indices)
+        n_struct = n_structures + 1
+
+        energy: Optional[torch.Tensor] = None
+        forces: Optional[torch.Tensor] = None
+        stress: Optional[torch.Tensor] = None
+
+        if energy_target_name is not None and energy_block_key is not None:
+            per_atom_energy = raw_predictions[energy_target_name][energy_block_key]
+            energy = torch.zeros(
+                n_struct, dtype=edge_vectors.dtype, device=edge_vectors.device
+            )
+            energy.scatter_add_(0, system_indices, per_atom_energy.squeeze(-1))
+
+        if (compute_forces or compute_stress) and energy is not None:
+            (dE_dR,) = torch.autograd.grad(
+                energy[:n_structures].sum(),
+                edge_vectors,
+                create_graph=False,
+            )
+            dE_dR = dE_dR * padding_mask[:, :, None].float()
+
+            if compute_forces:
+                # d(E)/d(pos[i]):
+                #   as center:   -sum_j dE_dR[i, j]
+                #   as neighbor: +sum_{(k, j): neighbor_atom = i} dE_dR[k, j]
+                grad_as_center = -dE_dR.sum(dim=1)  # [n_atoms, 3]
+                flat_dE = dE_dR.reshape(-1, 3)
+                flat_idx = neighbor_atom_indices.reshape(-1, 1).expand(-1, 3).long()
+                grad_as_neighbor = torch.zeros(
+                    n_atoms, 3, dtype=edge_vectors.dtype, device=edge_vectors.device
+                )
+                grad_as_neighbor.scatter_add_(0, flat_idx, flat_dE)
+                forces = grad_as_center + grad_as_neighbor
+
+            if compute_stress:
+                # Virial accumulated per structure: sum over edges of
+                # edge_vector (outer product) dE/d(edge_vector)
+                virial_per_atom = torch.einsum("ema,emb->eab", edge_vectors, dE_dR)
+                stress_buf = torch.zeros(
+                    n_struct,
+                    3,
+                    3,
+                    dtype=edge_vectors.dtype,
+                    device=edge_vectors.device,
+                )
+                stress_buf.scatter_add_(
+                    0,
+                    system_indices[:, None, None].expand(-1, 3, 3),
+                    virial_per_atom,
+                )
+                stress = stress_buf[:n_structures]
+
+        if energy is not None:
+            energy = energy[:n_structures]
+
+        return energy, forces, stress, raw_predictions
+
+    return forward_fn
+
+
+def compile_pet_model(
+    model: torch.nn.Module,
+    train_dataloader: Any,
+    compute_forces: bool,
+    compute_stress: bool,
+) -> Tuple[torch.nn.Module, List[str], List[str]]:
+    """Trace and compile the PET model as a single FX graph.
+
+    :param model: The PET model instance.
+    :param train_dataloader: A dataloader providing a sample batch for tracing.
+    :param compute_forces: Whether force computation is included.
+    :param compute_stress: Whether stress computation is included.
+    :return: Tuple of ``(compiled_module, param_names, buffer_names)``.
+    """
+    from torch.fx.experimental.proxy_tensor import make_fx
+
+    from metatrain.utils.data import unpack_batch
+    from metatrain.utils.transfer import batch_to
+
+    from ..modules.structures import systems_to_batch
+
+    batch_model = _PETBatchForward(model)
+    replace_silu_modules(batch_model)
+
+    params = dict(batch_model.named_parameters())
+    buffers = dict(batch_model.named_buffers())
+    param_names = list(params.keys())
+    buffer_names = list(buffers.keys())
+
+    forward_fn = _make_pet_compiled_forward(
+        batch_model,
+        param_names,
+        buffer_names,
+        model.target_names,
+        model.output_shapes,
+        compute_forces,
+        compute_stress,
+    )
+
+    # Get a sample batch for tracing
+    batch = next(iter(train_dataloader))
+    systems, _targets, _extra_data = unpack_batch(batch)
+    device = next(model.parameters()).device
+    dtype = next(model.parameters()).dtype
+    systems, _, _ = batch_to(systems, {}, {}, dtype=dtype, device=device)
+
+    (
+        element_indices_nodes,
+        element_indices_neighbors,
+        edge_vectors,
+        edge_distances,
+        padding_mask,
+        reverse_neighbor_index,
+        cutoff_factors,
+        system_indices,
+        neighbor_atom_indices,
+        _sample_labels,
+    ) = systems_to_batch(
+        systems,
+        model.requested_nl,
+        model.atomic_types,
+        model.species_to_species_index,
+        model.cutoff_function,
+        model.cutoff_width,
+        model.num_neighbors_adaptive,
+    )
+
+    n_structures = int(system_indices.max().item()) + 1
+
+    # edge_vectors needs grad for force tracing
+    tracing_edge_vectors = edge_vectors.clone().requires_grad_(True)
+
+    logging.info("Tracing PET model with make_fx (symbolic tracing)...")
+
+    old_duck = torch.fx.experimental._config.use_duck_shape
+    torch.fx.experimental._config.use_duck_shape = False
+    try:
+        fx_graph = make_fx(
+            forward_fn,
+            tracing_mode="symbolic",
+            _allow_non_fake_inputs=True,
+        )(
+            tracing_edge_vectors,
+            element_indices_nodes,
+            element_indices_neighbors,
+            padding_mask,
+            reverse_neighbor_index,
+            cutoff_factors,
+            system_indices,
+            neighbor_atom_indices,
+            n_structures,
+            *list(params.values()),
+            *list(buffers.values()),
+        )
+    finally:
+        torch.fx.experimental._config.use_duck_shape = old_duck
+
+    logging.info("Compiling traced FX graph with torch.compile...")
+    compiled = torch.compile(fx_graph, dynamic=True, fullgraph=True)
+
+    return compiled, param_names, buffer_names
diff --git a/src/metatrain/pet/modules/structures.py b/src/metatrain/pet/modules/structures.py
index 8438a7c67e..291556f9d8 100644
--- a/src/metatrain/pet/modules/structures.py
+++ b/src/metatrain/pet/modules/structures.py
@@ -154,6 +154,7 @@ def systems_to_batch(
     torch.Tensor,
     torch.Tensor,
     torch.Tensor,
+    torch.Tensor,
     Labels,
 ]:
     """
@@ -181,6 +182,8 @@ def systems_to_batch(
     - `reverse_neighbor_index`: The reversed neighbor list for each central atom
     - `cutoff_factors`: The cutoff function values for each edge
     - `system_indices`: The system index for each atom in the batch
+    - `neighbor_atom_indices`: The atom index of each neighbor in NEF format,
+      used to scatter edge-level gradients to per-atom forces in the compiled path
     - `sample_labels`: Labels indicating the system and atom indices for each atom
     """

@@ -296,6 +299,7 @@ def systems_to_batch(
     reversed_neighbor_list = compute_reversed_neighbor_list(
         nef_indices, corresponding_edges, nef_mask
     )
+    neighbor_atom_indices = edge_array_to_nef(neighbors, nef_indices, nef_mask, 0.0)
     neighbors_index = edge_array_to_nef(neighbors, nef_indices).to(torch.int64)

     # Here, we compute the array that allows indexing into a flattened
@@ -320,5 +324,6 @@ def systems_to_batch(
         reverse_neighbor_index,
         cutoff_factors,
         system_indices,
+        neighbor_atom_indices,
         sample_labels,
     )
diff --git a/src/metatrain/pet/modules/utilities.py b/src/metatrain/pet/modules/utilities.py
index 5795c5aed8..2a3e88f4a3 100644
--- a/src/metatrain/pet/modules/utilities.py
+++ b/src/metatrain/pet/modules/utilities.py
@@ -52,6 +52,39 @@ def cutoff_func_cosine(
     return f


+class DecomposedSiLU(torch.nn.Module):
+    """SiLU activation implemented as ``x * sigmoid(x)``.
+
+    Unlike ``torch.nn.SiLU``, this decomposes into primitive ops so that
+    ``make_fx`` produces a backward graph without ``silu_backward`` nodes.
+    This is needed for ``torch.compile`` (inductor) to differentiate through
+    the inlined backward when using the FX compilation path for force training.
+    """
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        return x * torch.sigmoid(x)
+
+
+def replace_silu_modules(module: torch.nn.Module) -> None:
+    """Replace all ``torch.nn.SiLU`` instances with :class:`DecomposedSiLU`.
+
+    Recurses through the module tree, including inside ``nn.Sequential``.
+
+    :param module: The module to recursively modify in-place.
+    """
+    for name, child in module.named_children():
+        if isinstance(child, torch.nn.SiLU):
+            setattr(module, name, DecomposedSiLU())
+        elif isinstance(child, torch.nn.Sequential):
+            for i, layer in enumerate(child):
+                if isinstance(layer, torch.nn.SiLU):
+                    child[i] = DecomposedSiLU()
+                else:
+                    replace_silu_modules(layer)
+        else:
+            replace_silu_modules(child)
+
+
 class DummyModule(torch.nn.Module):
     """Dummy torch module to make torchscript happy. This model should never be run"""
diff --git a/src/metatrain/pet/tests/checkpoints/model-v11_trainer-v13.ckpt.gz b/src/metatrain/pet/tests/checkpoints/model-v11_trainer-v13.ckpt.gz
new file mode 100644
index 0000000000..c5c58fb9f2
Binary files /dev/null and b/src/metatrain/pet/tests/checkpoints/model-v11_trainer-v13.ckpt.gz differ
diff --git a/src/metatrain/pet/tests/test_compile.py b/src/metatrain/pet/tests/test_compile.py
new file mode 100644
index 0000000000..7ca028b824
--- /dev/null
+++ b/src/metatrain/pet/tests/test_compile.py
@@ -0,0 +1,226 @@
+"""Tests for torch.compile support in PET."""
+
+import copy
+
+import pytest
+import torch
+from metatomic.torch import ModelOutput
+
+from metatrain.pet.modules.transformer import CartesianTransformer
+from metatrain.utils.architectures import get_default_hypers
+from metatrain.utils.testing import ArchitectureTests, TrainingTests
+
+
+class PETTests(ArchitectureTests):
+    architecture = "pet"
+
+    @pytest.fixture
+    def minimal_model_hypers(self):
+        hypers = get_default_hypers(self.architecture)["model"]
+        hypers = copy.deepcopy(hypers)
+        hypers["d_pet"] = 1
+        hypers["d_head"] = 1
+        hypers["d_node"] = 1
+        hypers["d_feedforward"] = 1
+        hypers["num_heads"] = 1
+        hypers["num_attention_layers"] = 1
+        hypers["num_gnn_layers"] = 1
+        return hypers
+
+
+def _make_cartesian_transformer(is_first=True, transformer_type="PreLN"):
+    """Helper to create a test CartesianTransformer."""
+    return CartesianTransformer(
+        cutoff=4.5,
+        cutoff_width=0.5,
+        d_model=8,
+        n_head=2,
+        dim_node_features=16,
+        dim_feedforward=8,
+        n_layers=2,
+        norm="RMSNorm",
+        activation="SwiGLU",
+        attention_temperature=1.0,
+        transformer_type=transformer_type,
+        n_atomic_species=4,
+        is_first=is_first,
+    )
+
+
+def _make_inputs(n_atoms=5, max_neighbors=10, d_model=8, dim_node_features=16):
+    """Helper to create test inputs for CartesianTransformer."""
+    input_node_embeddings = torch.randn(n_atoms, dim_node_features)
+    input_messages = torch.randn(n_atoms, max_neighbors, d_model)
+    element_indices_neighbors = torch.randint(0, 4, (n_atoms, max_neighbors))
+    edge_vectors = torch.randn(n_atoms, max_neighbors, 3)
+    padding_mask = torch.ones(n_atoms, max_neighbors, dtype=torch.bool)
+    padding_mask[:, -3:] = False
+    edge_distances = torch.randn(n_atoms, max_neighbors).abs()
+    cutoff_factors = torch.rand(n_atoms, max_neighbors)
+    cutoff_factors[~padding_mask] = 0.0
+    return (
+        input_node_embeddings,
+        input_messages,
+        element_indices_neighbors,
+        edge_vectors,
+        padding_mask,
+        edge_distances,
+        cutoff_factors,
+    )
+
+
+def test_compile_cartesian_transformer():
+    """Test CartesianTransformer with fullgraph=True and SDPA attention."""
+    ct = _make_cartesian_transformer()
+    compiled_ct = torch.compile(ct, fullgraph=True)
+
+    inputs = _make_inputs()
+    out_eager = ct(*inputs, False)
+    out_compiled = compiled_ct(*inputs, False)
+
+    assert torch.allclose(out_eager[0], out_compiled[0], atol=1e-5)
+    assert torch.allclose(out_eager[1], out_compiled[1], atol=1e-5)
+
+
+def test_compile_manual_attention():
+    """Test that CartesianTransformer compiles with the manual attention path."""
+    ct = _make_cartesian_transformer()
+    compiled_ct = torch.compile(ct, fullgraph=True)
+
+    inputs = _make_inputs()
+    out_eager = ct(*inputs, True)
+    out_compiled = compiled_ct(*inputs, True)
+
+    assert torch.allclose(out_eager[0], out_compiled[0], atol=1e-5)
+    assert torch.allclose(out_eager[1], out_compiled[1], atol=1e-5)
+
+
+def test_compile_backward():
+    """Test that a single backward through the compiled CartesianTransformer works."""
+    ct = _make_cartesian_transformer()
+    compiled_ct = torch.compile(ct, fullgraph=True)
+
+    inputs = list(_make_inputs())
+    inputs[3] = inputs[3].requires_grad_(True)  # edge_vectors
+
+    out = compiled_ct(*inputs, False)
+    loss = out[0].sum() + out[1].sum()
+    loss.backward()
+
+    assert inputs[3].grad is not None
+    assert inputs[3].grad.shape == inputs[3].shape
+
+
+def test_compile_not_first_layer():
+    """Test compilation of a non-first CartesianTransformer (different forward branch)."""
+    ct = _make_cartesian_transformer(is_first=False)
+    compiled_ct = torch.compile(ct, fullgraph=True)
+
+    inputs = _make_inputs()
+    out_eager = ct(*inputs, False)
+    out_compiled = compiled_ct(*inputs, False)
+
+    assert torch.allclose(out_eager[0], out_compiled[0], atol=1e-5)
+    assert torch.allclose(out_eager[1], out_compiled[1], atol=1e-5)
+
+
+def test_compile_postln():
+    """Test compilation with the PostLN transformer type."""
+    ct = _make_cartesian_transformer(transformer_type="PostLN")
+    compiled_ct = torch.compile(ct, fullgraph=True)
+
+    inputs = _make_inputs()
+    out_eager = ct(*inputs, False)
+    out_compiled = compiled_ct(*inputs, False)
+
+    assert torch.allclose(out_eager[0], out_compiled[0], atol=1e-5)
+    assert torch.allclose(out_eager[1], out_compiled[1], atol=1e-5)
+
+
+def test_forward_from_batch():
+    """Test that _forward_from_batch matches forward for per-atom energy."""
+    from metatrain.pet import PET
+    from metatrain.utils.data import DatasetInfo
+    from metatrain.utils.data.readers import read_systems
+    from metatrain.utils.data.target_info import get_energy_target_info
+    from metatrain.utils.neighbor_lists import get_system_with_neighbor_lists
+
+    from ..modules.structures import systems_to_batch
+    from . import DATASET_PATH, MODEL_HYPERS
+
+    torch.manual_seed(42)
+
+    targets = {
+        "mtt::U0": get_energy_target_info(
+            "mtt::U0", {"quantity": "energy", "unit": "eV"}
+        )
+    }
+    dataset_info = DatasetInfo(
+        length_unit="Angstrom", atomic_types=[1, 6, 7, 8], targets=targets
+    )
+    model = PET(MODEL_HYPERS, dataset_info)
+    model.eval()
+
+    systems = read_systems(DATASET_PATH)[:3]
+    systems = [s.to(torch.float32) for s in systems]
+    for s in systems:
+        get_system_with_neighbor_lists(s, model.requested_neighbor_lists())
+
+    # Get per-atom predictions from forward
+    forward_output = model(
+        systems,
+        {"mtt::U0": ModelOutput(quantity="energy", unit="", per_atom=True)},
+    )
+    forward_per_atom = forward_output["mtt::U0"].block().values
+
+    # Get per-atom predictions from _forward_from_batch
+    (
+        element_indices_nodes,
+        element_indices_neighbors,
+        edge_vectors,
+        edge_distances,
+        padding_mask,
+        reverse_neighbor_index,
+        cutoff_factors,
+        system_indices,
+        neighbor_atom_indices,
+        sample_labels,
+    ) = systems_to_batch(
+        systems,
+        model.requested_nl,
+        model.atomic_types,
+        model.species_to_species_index,
+        model.cutoff_function,
+        model.cutoff_width,
+        model.num_neighbors_adaptive,
+    )
+
+    batch_output = model._forward_from_batch(
+        element_indices_nodes,
+        element_indices_neighbors,
+        edge_vectors,
+        edge_distances,
+        padding_mask,
+        reverse_neighbor_index,
+        cutoff_factors,
+    )
+    # Get the first (and only) block key for the energy target
+    energy_key = next(iter(model.output_shapes["mtt::U0"]))
+    batch_per_atom = batch_output["mtt::U0"][energy_key]
+
+    torch.testing.assert_close(forward_per_atom, batch_per_atom, atol=1e-6, rtol=1e-6)
+
+
+class TestTrainingCompile(TrainingTests, PETTests):
+    """Run the standard training tests with compile=True.
+
+    The full-graph FX compilation path traces the entire PET model
+    (including force/stress computation) into a single FX graph and
+    compiles it with ``torch.compile(dynamic=True, fullgraph=True)``.
+    """
+
+    @pytest.fixture
+    def default_hypers(self):
+        hypers = get_default_hypers(self.architecture)
+        hypers["training"]["compile"] = True
+        return hypers
diff --git a/src/metatrain/pet/trainer.py b/src/metatrain/pet/trainer.py
index 8194adc7de..c114be5dcb 100644
--- a/src/metatrain/pet/trainer.py
+++ b/src/metatrain/pet/trainer.py
@@ -5,6 +5,8 @@
 from typing import Any, Dict, List, Literal, Optional, Union, cast

 import torch
+from metatensor.torch import Labels, TensorBlock, TensorMap
+from metatomic.torch import System
 from torch.optim.lr_scheduler import LambdaLR
 from torch.utils.data import DataLoader, DistributedSampler

@@ -41,6 +43,137 @@
 from .documentation import TrainerHypers
 from .model import PET
 from .modules.finetuning import apply_finetuning_strategy
+from .modules.structures import systems_to_batch
+
+
+def _wrap_compiled_output(
+    energy: Optional[torch.Tensor],
+    forces: Optional[torch.Tensor],
+    stress: Optional[torch.Tensor],
+    raw_predictions: Dict[str, Dict[str, torch.Tensor]],
+    model: PET,
+    systems: List[System],
+    sample_labels: Labels,
+    system_indices: torch.Tensor,
+    train_targets: Dict,
+) -> Dict[str, TensorMap]:
+    """Convert compiled function outputs to ``Dict[str, TensorMap]``.
+
+    Produces the same format as ``evaluate_model`` so the loss function
+    and metric accumulators work unchanged.
+
+    :param energy: Per-structure energy tensor from the compiled function,
+        or ``None``.
+    :param forces: Per-atom force tensor, or ``None``.
+    :param stress: Per-structure stress tensor, or ``None``.
+    :param raw_predictions: Per-atom predictions keyed by target and block.
+    :param model: The PET model instance.
+    :param systems: The input systems for this batch.
+    :param sample_labels: Labels indicating system and atom indices.
+    :param system_indices: System index for each atom in the batch.
+    :param train_targets: Target information dict from the training config.
+    :return: Predictions as ``Dict[str, TensorMap]``.
+    """
+    from metatrain.utils.sum_over_atoms import sum_over_atoms
+
+    device = system_indices.device
+    predictions: Dict[str, TensorMap] = {}
+
+    # Identify the energy target
+    energy_target_name = None
+    for tname in model.target_names:
+        if tname in model.outputs and model.outputs[tname].quantity == "energy":
+            energy_target_name = tname
+            break
+
+    # Build the energy TensorMap (per-structure) with optional gradient blocks
+    if energy_target_name is not None and energy is not None:
+        n_structures = energy.shape[0]
+        energy_block = TensorBlock(
+            values=energy.unsqueeze(-1),
+            samples=Labels(
+                "system",
+                torch.arange(n_structures, device=device, dtype=torch.int32).unsqueeze(
+                    -1
+                ),
+                assume_unique=True,
+            ),
+            components=[],
+            properties=Labels("energy", torch.tensor([[0]], device=device)),
+        )
+
+        if forces is not None:
+            # Position gradient block: samples are ["sample", "atom"],
+            # matching evaluate_model's _position_gradients_to_block format
+            grad_samples = Labels(
+                names=["sample", "atom"],
+                values=sample_labels.values.to(torch.int32),
+                assume_unique=True,
+            ).to(device)
+            xyz_labels = Labels("xyz", torch.tensor([[0], [1], [2]], device=device))
+            forces_block = TensorBlock(
+                values=forces.unsqueeze(-1),
+                samples=grad_samples,
+                components=[xyz_labels],
+                properties=Labels("energy", torch.tensor([[0]], device=device)),
+            )
+            energy_block.add_gradient("positions", forces_block)
+
+        if stress is not None:
+            stress_samples = Labels(
+                "sample",
+                torch.arange(n_structures, device=device, dtype=torch.int32).unsqueeze(
+                    -1
+                ),
+                assume_unique=True,
+            )
+            xyz1 = Labels("xyz_1", torch.tensor([[0], [1], [2]], device=device))
+            xyz2 = Labels("xyz_2", torch.tensor([[0], [1], [2]], device=device))
+            stress_block = TensorBlock(
+                values=stress.unsqueeze(-1),
+                samples=stress_samples,
+                components=[xyz1, xyz2],
+                properties=Labels("energy", torch.tensor([[0]], device=device)),
+            )
+            energy_block.add_gradient("strain", stress_block)
+
+        predictions[energy_target_name] = TensorMap(
+            keys=model.single_label,
+            blocks=[energy_block],
+        )
+
+    # Non-energy targets: wrap per-atom raw predictions into TensorMaps
+    for target_name in model.target_names:
+        if target_name == energy_target_name:
+            continue
+        if target_name not in raw_predictions:
+            continue
+        if target_name not in train_targets:
+            continue
+
+        target_preds = raw_predictions[target_name]
+        blocks = []
+        for key, shape, components, properties in zip(
+            model.output_shapes[target_name].keys(),
+            model.output_shapes[target_name].values(),
+            model.component_labels[target_name],
+            model.property_labels[target_name],
+            strict=True,
+        ):
+            values = target_preds[key].reshape([-1] + shape)
+            block = TensorBlock(
+                values=values,
+                samples=sample_labels,
+                components=components,
+                properties=properties,
+            )
+            blocks.append(block)
+
+        tmap = TensorMap(keys=model.key_labels[target_name], blocks=blocks)
+        if not train_targets[target_name].per_atom:
+            tmap = sum_over_atoms(tmap)
+        predictions[target_name] = tmap
+
+    return predictions


 def get_scheduler(
@@ -77,7 +210,7 @@ def lr_lambda(current_step: int) -> float:


 class Trainer(TrainerInterface[TrainerHypers]):
-    __checkpoint_version__ = 12
+    __checkpoint_version__ = 13

     def __init__(self, hypers: TrainerHypers) -> None:
         super().__init__(hypers)
@@ -160,6 +293,26 @@ def train(
             additive_model.to(dtype=torch.float64)
         model.scaler.to(dtype=torch.float64)

+        # torch.compile: full-graph FX compilation of the entire model
+        # (including force/stress computation via autograd.grad).
+        compile_enabled = self.hypers.get("compile", False)
+        has_gradients = any(
+            len(target_info.gradients) > 0
+            for target_info in model.dataset_info.targets.values()
+        )
+        has_strain_gradients = any(
+            "strain" in target_info.gradients
+            for target_info in model.dataset_info.targets.values()
+        )
+        if compile_enabled:
+            torch._dynamo.reset()
+            if is_distributed:
+                logging.warning(
+                    "torch.compile with DDP is not yet supported. "
+                    "Disabling compilation for distributed training."
+                )
+                compile_enabled = False
+
         logging.info("Calculating composition weights")
         model.additive_models[0].train_model(  # this is the composition model
             train_datasets,
@@ -357,6 +510,21 @@
         # Log the initial learning rate:
         logging.info(f"Base learning rate: {self.hypers['learning_rate']}")

+        # Full-graph FX compilation (after dataloaders are ready for tracing).
+        compiled_fn = None
+        if compile_enabled:
+            from .modules.compile import compile_pet_model
+
+            compiled_fn, _compiled_param_names, _compiled_buffer_names = (
+                compile_pet_model(
+                    model,
+                    train_dataloader,
+                    has_gradients,
+                    has_strain_gradients,
+                )
+            )
+            logging.info("FX compilation complete (will optimize on first call)")
+
         start_epoch = 0 if self.epoch is None else self.epoch + 1

         # Train the model:
@@ -389,27 +557,85 @@ def train(
                 systems, targets, extra_data = batch_to(
                     systems, targets, extra_data, dtype=dtype, device=device
                 )
-                predictions = evaluate_model(
-                    model,
-                    systems,
-                    {key: train_targets[key] for key in targets.keys()},
-                    is_training=True,
-                )
-
-                # average by the number of atoms
-                predictions = average_by_num_atoms(
-                    predictions, systems, per_structure_targets
-                )
-                targets = average_by_num_atoms(targets, systems, per_structure_targets)
-                train_loss_batch = loss_fn(predictions, targets, extra_data)
+                if compile_enabled and compiled_fn is not None:
+                    # FX-compiled path: call systems_to_batch directly,
+                    # run the compiled function, and wrap outputs.
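+                    # Parameters and buffers are passed positionally, in the same
+                    # order used at trace time, so gradients computed inside the
+                    # compiled graph flow back to the live model parameters.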
+                    (
+                        c_element_indices_nodes,
+                        c_element_indices_neighbors,
+                        c_edge_vectors,
+                        _c_edge_distances,
+                        c_padding_mask,
+                        c_reverse_neighbor_index,
+                        c_cutoff_factors,
+                        c_system_indices,
+                        c_neighbor_atom_indices,
+                        c_sample_labels,
+                    ) = systems_to_batch(
+                        systems,
+                        model.requested_nl,
+                        model.atomic_types,
+                        model.species_to_species_index,
+                        model.cutoff_function,
+                        model.cutoff_width,
+                        model.num_neighbors_adaptive,
+                    )
+                    if has_gradients:
+                        c_edge_vectors = c_edge_vectors.requires_grad_(True)
+                    n_structures = len(systems)
+                    energy, forces, stress, raw_preds = compiled_fn(
+                        c_edge_vectors,
+                        c_element_indices_nodes,
+                        c_element_indices_neighbors,
+                        c_padding_mask,
+                        c_reverse_neighbor_index,
+                        c_cutoff_factors,
+                        c_system_indices,
+                        c_neighbor_atom_indices,
+                        n_structures,
+                        *list(model.parameters()),
+                        *list(model.buffers()),
+                    )
+                    predictions = _wrap_compiled_output(
+                        energy,
+                        forces,
+                        stress,
+                        raw_preds,
+                        model,
+                        systems,
+                        c_sample_labels,
+                        c_system_indices,
+                        train_targets,
+                    )
+                    predictions = average_by_num_atoms(
+                        predictions, systems, per_structure_targets
+                    )
+                    targets = average_by_num_atoms(
+                        targets, systems, per_structure_targets
+                    )
+                    train_loss_batch = loss_fn(predictions, targets, extra_data)
+                    train_loss_batch.backward()
+                else:
+                    predictions = evaluate_model(
+                        model,
+                        systems,
+                        {key: train_targets[key] for key in targets.keys()},
+                        is_training=True,
+                    )
+                    predictions = average_by_num_atoms(
+                        predictions, systems, per_structure_targets
+                    )
+                    targets = average_by_num_atoms(
+                        targets, systems, per_structure_targets
+                    )
+                    train_loss_batch = loss_fn(predictions, targets, extra_data)

-                if is_distributed:
-                    # make sure all parameters contribute to the gradient calculation
-                    # to make torch DDP happy
-                    for param in model.parameters():
-                        train_loss_batch += 0.0 * param.sum()
+                    if is_distributed:
+                        for param in model.parameters():
+                            train_loss_batch += 0.0 * param.sum()

-                train_loss_batch.backward()
+                    train_loss_batch.backward()
                 torch.nn.utils.clip_grad_norm_(
                     model.parameters(), self.hypers["grad_clip_norm"]
                 )
@@ -538,9 +764,13 @@
                 )
                 if val_metric < self.best_metric:
                     self.best_metric = val_metric
-                    self.best_model_state_dict = copy.deepcopy(
-                        (model.module if is_distributed else model).state_dict()
-                    )
+                    raw_state_dict = (
+                        model.module if is_distributed else model
+                    ).state_dict()
+                    self.best_model_state_dict = {
+                        k.replace("._orig_mod", ""): v.clone()
+                        for k, v in raw_state_dict.items()
+                    }
                     self.best_epoch = epoch

                     self.best_optimizer_state_dict = copy.deepcopy(optimizer.state_dict())