8 changes: 8 additions & 0 deletions docs/src/dev-docs/changelog.rst
@@ -24,6 +24,14 @@ changelog <https://keepachangelog.com/en/1.1.0/>`_ format. This project follows
Unreleased
----------

Added
#####

- The PET architecture now supports full-graph FX compilation for training via the
``compile`` hyperparameter. When enabled, the entire model (including force/stress
computation) is traced into a single FX graph and compiled with ``torch.compile``,
which maximizes kernel fusion and consistently uses scaled dot-product attention.

Version 2026.1 - 2026-01-07
---------------------------

7 changes: 6 additions & 1 deletion pyproject.toml
@@ -163,6 +163,7 @@ filterwarnings = [
"ignore:`torch.jit.script` is deprecated. Please switch to `torch.compile` or `torch.export`:DeprecationWarning",
"ignore:`torch.jit.save` is deprecated. Please switch to `torch.export`:DeprecationWarning",
"ignore:`torch.jit.load` is deprecated. Please switch to `torch.export`:DeprecationWarning",
"ignore:`torch.jit.script_method` is deprecated. Please switch to `torch.compile` or `torch.export`:DeprecationWarning",
"ignore:`torch.jit.trace_method` is deprecated. Please switch to `torch.compile` or `torch.export`:DeprecationWarning",
"ignore:`torch.jit.trace` is deprecated. Please switch to `torch.compile` or `torch.export`:DeprecationWarning",
# PyTorch does not want these, but mypy requires them
@@ -177,7 +178,11 @@ filterwarnings = [
# Multi-threaded tests clash with multi-process data-loading
"ignore:This process \\(pid=\\d+\\) is multi-threaded, use of fork\\(\\) may lead to deadlocks in the child.:DeprecationWarning",
# MACE warning with newer versions of pytorch (because they use e3nn==0.4.4)
"ignore:Environment variable TORCH_FORCE_NO_WEIGHTS_ONLY_LOAD detected, since the`weights_only` argument was not explicitly passed:UserWarning"
"ignore:Environment variable TORCH_FORCE_NO_WEIGHTS_ONLY_LOAD detected, since the`weights_only` argument was not explicitly passed:UserWarning",
# compiled_autograd + non-leaf tensors: dynamo incorrectly converts this C++
# warning to an error when tracing backward graphs through non-leaf tensors
# (e.g. edge_distances computed from edge_vectors via sqrt)
"ignore:The .grad attribute of a Tensor that is not a leaf Tensor is being accessed:UserWarning"
]
addopts = ["-p", "mtt_plugin"]
pythonpath = "src/metatrain/utils/testing"
9 changes: 9 additions & 0 deletions src/metatrain/pet/checkpoints.py
@@ -420,3 +420,12 @@ def trainer_update_v11_v12(checkpoint: dict) -> None:
:param checkpoint: The checkpoint to update.
"""
checkpoint["train_hypers"]["batch_atom_bounds"] = [None, None]


def trainer_update_v12_v13(checkpoint: dict) -> None:
"""
Update trainer checkpoint from version 12 to version 13.

:param checkpoint: The checkpoint to update.
"""
checkpoint["train_hypers"]["compile"] = False
11 changes: 11 additions & 0 deletions src/metatrain/pet/documentation.py
@@ -257,3 +257,14 @@ class TrainerHypers(TypedDict):

See :ref:`label_fine_tuning_concept` for more details.
"""
compile: bool = False
"""Whether to use full-graph FX compilation during training.

When enabled, the entire PET model (including force/stress computation via
``autograd.grad``) is traced into a single FX graph using ``make_fx`` and
then compiled with ``torch.compile(dynamic=True, fullgraph=True)``. This
gives maximum kernel fusion, avoids compiled/eager boundary crossings, and
always uses ``scaled_dot_product_attention`` (SDPA). Expect a one-time
compilation cost at the start of training, followed by speedups on every
subsequent step.
"""
81 changes: 79 additions & 2 deletions src/metatrain/pet/model.py
@@ -416,6 +416,7 @@ def forward(
reverse_neighbor_index,
cutoff_factors,
system_indices,
_neighbor_atom_indices,
sample_labels,
) = systems_to_batch(
systems,
@@ -570,6 +571,78 @@

return return_dict

def _forward_from_batch(
self,
element_indices_nodes: torch.Tensor,
element_indices_neighbors: torch.Tensor,
edge_vectors: torch.Tensor,
edge_distances: torch.Tensor,
padding_mask: torch.Tensor,
reverse_neighbor_index: torch.Tensor,
cutoff_factors: torch.Tensor,
) -> Dict[str, Dict[str, torch.Tensor]]:
"""
Pure-tensor forward pass for FX compilation.

Takes batch tensors and returns raw per-atom predictions as nested
dictionaries (target_name -> block_key -> tensor). Always uses SDPA
(the manual attention path is not needed, since forces are computed via
``autograd.grad(create_graph=False)`` inside the compiled graph, avoiding
double backward).

:param element_indices_nodes: Atomic species of central atoms [n_atoms].
:param element_indices_neighbors: Atomic species of neighbors
[n_atoms, max_neighbors].
:param edge_vectors: Edge vectors [n_atoms, max_neighbors, 3].
:param edge_distances: Edge distances [n_atoms, max_neighbors].
:param padding_mask: Boolean mask for real neighbors
[n_atoms, max_neighbors].
:param reverse_neighbor_index: Reversed neighbor index for message
passing [n_atoms, max_neighbors].
:param cutoff_factors: Cutoff function values [n_atoms, max_neighbors].
:return: Nested dict mapping target_name -> block_key -> per-atom
prediction tensor.
"""
featurizer_inputs: Dict[str, torch.Tensor] = dict(
element_indices_nodes=element_indices_nodes,
element_indices_neighbors=element_indices_neighbors,
edge_vectors=edge_vectors,
edge_distances=edge_distances,
reverse_neighbor_index=reverse_neighbor_index,
padding_mask=padding_mask,
cutoff_factors=cutoff_factors,
)

# Always use SDPA (no double backward in compiled path)
node_features_list, edge_features_list = self._calculate_features(
featurizer_inputs, use_manual_attention=False
)

node_ll_dict, edge_ll_dict = self._calculate_last_layer_features(
node_features_list, edge_features_list
)

outputs_all: Dict[str, ModelOutput] = {
name: ModelOutput(per_atom=True) for name in self.target_names
}
node_preds, edge_preds = self._calculate_atomic_predictions(
node_ll_dict, edge_ll_dict, padding_mask, cutoff_factors, outputs_all
)

# Sum across GNN layers for each target/block
results: Dict[str, Dict[str, torch.Tensor]] = {}
for target_name in self.target_names:
block_results: Dict[str, torch.Tensor] = {}
node_layers = node_preds[target_name]
edge_layers = edge_preds[target_name]
for j, key in enumerate(self.output_shapes[target_name]):
total = node_layers[0][j] + edge_layers[0][j]
for i in range(1, len(node_layers)):
total = total + node_layers[i][j] + edge_layers[i][j]
block_results[key] = total
results[target_name] = block_results
return results

def _calculate_features(
self, inputs: Dict[str, torch.Tensor], use_manual_attention: bool
) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:
@@ -1340,7 +1413,11 @@ def upgrade_checkpoint(cls, checkpoint: Dict) -> Dict:
return checkpoint

def get_checkpoint(self) -> Dict:
model_state_dict = self.state_dict()
# Get state dict, handling compiled modules by removing _orig_mod prefix
state_dict = {
k.replace("._orig_mod", ""): v for k, v in self.state_dict().items()
}
model_state_dict = dict(state_dict)
model_state_dict["finetune_config"] = self.finetune_config
checkpoint = {
"architecture_name": "pet",
@@ -1353,7 +1430,7 @@
"epoch": None,
"best_epoch": None,
"model_state_dict": model_state_dict,
"best_model_state_dict": self.state_dict(),
"best_model_state_dict": state_dict,
}
return checkpoint

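For context on the ``._orig_mod`` renaming in ``get_checkpoint`` above, a small illustration with a toy module (not PET, and relying on the assumed behavior of ``torch.compile``'s module wrapper): compiling a submodule in place wraps it in an ``OptimizedModule``, which adds an ``_orig_mod`` level to ``state_dict()`` keys, so stripping it keeps checkpoints loadable by an uncompiled model.

# Toy illustration of the key prefix introduced by compiling a submodule.
import torch


class Parent(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.child = torch.nn.Linear(4, 4)


parent = Parent()
parent.child = torch.compile(parent.child)  # compile a submodule in place
print(list(parent.state_dict()))
# expected: ['child._orig_mod.weight', 'child._orig_mod.bias']

clean = {k.replace("._orig_mod", ""): v for k, v in parent.state_dict().items()}
print(list(clean))  # ['child.weight', 'child.bias']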