From 6be4b78b7ac1a0b4d6ed937cbbff55775436f58d Mon Sep 17 00:00:00 2001 From: Pol Febrer Calabozo Date: Fri, 16 Jan 2026 16:07:49 +0100 Subject: [PATCH 1/5] Make MACE work with LLPR --- src/metatrain/experimental/mace/model.py | 43 +++++++++++++++++------- src/metatrain/llpr/model.py | 39 ++++++++++++++++----- src/metatrain/utils/testing/output.py | 5 +++ 3 files changed, 67 insertions(+), 20 deletions(-) diff --git a/src/metatrain/experimental/mace/model.py b/src/metatrain/experimental/mace/model.py index f30a5968cc..a9eec78a89 100644 --- a/src/metatrain/experimental/mace/model.py +++ b/src/metatrain/experimental/mace/model.py @@ -1,3 +1,4 @@ +import copy import logging import warnings from pathlib import Path @@ -130,7 +131,7 @@ def __init__(self, hypers: ModelHypers, dataset_info: DatasetInfo) -> None: if hasattr(self.mace_model, "atomic_energies_fn"): self._loaded_atomic_baseline = ( self.mace_model.atomic_energies_fn.atomic_energies.clone() - ) + ).ravel() self.mace_model.atomic_energies_fn.atomic_energies[:] = 0.0 @@ -574,6 +575,22 @@ def load_checkpoint( return model + def _get_capabilities(self) -> ModelCapabilities: + dtype = next(self.parameters()).dtype + + interaction_range = self.hypers["num_interactions"] * self.cutoff + + capabilities = ModelCapabilities( + outputs=self.outputs, + atomic_types=self.atomic_types, + interaction_range=interaction_range, + length_unit=self.dataset_info.length_unit, + supported_devices=self.__supported_devices__, + dtype=dtype_to_str(dtype), + ) + + return capabilities + def export(self, metadata: Optional[ModelMetadata] = None) -> AtomisticModel: dtype = next(self.parameters()).dtype if dtype not in self.__supported_dtypes__: @@ -588,20 +605,18 @@ def export(self, metadata: Optional[ModelMetadata] = None) -> AtomisticModel: # be registered correctly with Pytorch. 
This function moves them: self.additive_models[0].weights_to(torch.device("cpu"), torch.float64) - interaction_range = self.hypers["num_interactions"] * self.cutoff - - capabilities = ModelCapabilities( - outputs=self.outputs, - atomic_types=self.atomic_types, - interaction_range=interaction_range, - length_unit=self.dataset_info.length_unit, - supported_devices=self.__supported_devices__, - dtype=dtype_to_str(dtype), - ) + capabilities = self._get_capabilities() metadata = merge_metadata(self.metadata, metadata) - return AtomisticModel(jit.compile(self.eval()), metadata, capabilities) + if self.hypers["mace_model"] is not None: + to_export = copy.deepcopy(self.eval()) + else: + to_export = self.eval() + + model = AtomisticModel(jit.compile(to_export), metadata, capabilities) + + return model def _add_output(self, target_name: str, target_info: TargetInfo) -> None: """ @@ -625,6 +640,10 @@ def _add_output(self, target_name: str, target_info: TargetInfo) -> None: self.heads[target_name] = MACEHeadWrapper( self.mace_model.readouts, self.per_layer_irreps ) + + self.last_layer_feature_size = self.heads[ + target_name + ].last_layer_features_irreps.dim else: head = NonLinearHead( irreps_in=self.features_irreps, diff --git a/src/metatrain/llpr/model.py b/src/metatrain/llpr/model.py index 4918946e88..7a376b1bd1 100644 --- a/src/metatrain/llpr/model.py +++ b/src/metatrain/llpr/model.py @@ -292,14 +292,15 @@ def forward( # special case for energy_ensemble ll_features_name = "mtt::aux::energy_last_layer_features" ll_features = return_dict[ll_features_name] + ll_feats_vals = self.get_ll_feats_vals(ll_features) # compute PRs # the code is the same for PR and LPR one_over_pr_values = torch.einsum( "ij, jk, ik -> i", - ll_features.block().values, + ll_feats_vals, self._get_inv_covariance(uncertainty_name), - ll_features.block().values, + ll_feats_vals, ).unsqueeze(1) original_name = self._get_original_name(uncertainty_name) @@ -323,7 +324,7 @@ def forward( TensorBlock( values=torch.sqrt(one_over_pr_values.expand((-1, num_prop))), samples=ll_features.block().samples, - components=ll_features.block().components, + components=[], properties=cur_prop, ) ], @@ -344,7 +345,7 @@ def forward( one_over_pr_values.shape[0], num_prop ), samples=ll_features.block().samples, - components=ll_features.block().components, + components=[], properties=cur_prop, ) ], @@ -434,7 +435,7 @@ def forward( TensorBlock( values=ensemble_values, samples=ll_features.block().samples, - components=ll_features.block().components, + components=[], properties=ens_prop, ), ], @@ -486,9 +487,9 @@ class in ``metatrain``. # TODO: interface ll_feat calculation with the loss function, # paying attention to normalization w.r.t. n_atoms if not outputs_for_targets[name].per_atom: - ll_feats = ll_feat_tmap.block().values.detach() / n_atoms.unsqueeze( - 1 - ) + ll_feats = self.get_ll_feats_vals( + ll_feat_tmap + ).detach() / n_atoms.unsqueeze(1) uncertainty_name = _get_uncertainty_name(name) covariance = self._get_covariance(uncertainty_name) covariance += ll_feats.T @ ll_feats @@ -608,6 +609,26 @@ def calibrate(self, valid_loader: DataLoader) -> None: multiplier = self._get_multiplier(uncertainty_name) multiplier[:] = torch.sqrt(torch.mean(ratios, dim=0)) # only along samples + def get_ll_feats_vals(self, ll_features_tmap: TensorMap) -> torch.Tensor: + """Get the last-layer features values from a TensorMap. + + It ensures that the last-layer features are scalars and + makes sure that the values have no component dimension. 
+ + :param ll_features_tmap: A TensorMap containing last-layer features. + :return: A tensor with the last-layer features values. + """ + block = ll_features_tmap.block() + block_shape = block.values.shape + if len(block_shape) > 3 or (len(block_shape) == 3 and block_shape[1] != 1): + raise ValueError( + "Can't use last layer features other than scalars. Received " + f"last layer features are:\n{block}" + ) + # Make sure there is no component dimension + ll_feats_vals = block.values.reshape(block_shape[0], -1) + return ll_feats_vals + def generate_ensemble(self) -> None: """Generate an ensemble of weights for the model. @@ -770,6 +791,8 @@ def export(self, metadata: Optional[ModelMetadata] = None) -> AtomisticModel: # no weights to move pass + self.model = self.model.export().module + metadata = merge_metadata( merge_metadata(self.__default_metadata__, metadata), self.model.export().metadata(), diff --git a/src/metatrain/utils/testing/output.py b/src/metatrain/utils/testing/output.py index 929d2e3d15..4e2b6e37b9 100644 --- a/src/metatrain/utils/testing/output.py +++ b/src/metatrain/utils/testing/output.py @@ -518,6 +518,11 @@ def test_output_last_layer_features( model = self.model_cls(model_hypers, dataset_info) + assert hasattr(model, "last_layer_feature_size"), ( + f"{self.architecture} does not have the attribute " + "`last_layer_feature_size`." + ) + system = System( types=torch.tensor([6, 1, 8, 7]), positions=torch.tensor( From 9b6ecf2327a81d7061e8c495837e6ada42795c84 Mon Sep 17 00:00:00 2001 From: Pol Febrer Calabozo Date: Fri, 16 Jan 2026 16:13:58 +0100 Subject: [PATCH 2/5] Make sure the LLPR object is not modified on export --- src/metatrain/llpr/model.py | 13 ++++++++++--- 1 file changed, 10 insertions(+), 3 deletions(-) diff --git a/src/metatrain/llpr/model.py b/src/metatrain/llpr/model.py index 7a376b1bd1..d7a2b6a857 100644 --- a/src/metatrain/llpr/model.py +++ b/src/metatrain/llpr/model.py @@ -1,3 +1,4 @@ +import copy import logging from typing import Any, Dict, List, Literal, Optional @@ -791,14 +792,20 @@ def export(self, metadata: Optional[ModelMetadata] = None) -> AtomisticModel: # no weights to move pass - self.model = self.model.export().module + # Make sure the wrapped model is ready to be exported, for that + # we shallow-copy the LLPR model and replace the wrapped model + # in this shallow copy with the exported version of the + # wrapped model + to_export = copy.copy(self.eval()) + exported_wrapped = self.model.export() + to_export.model = exported_wrapped.module metadata = merge_metadata( merge_metadata(self.__default_metadata__, metadata), - self.model.export().metadata(), + exported_wrapped.metadata(), ) - return AtomisticModel(self.eval(), metadata, self.capabilities) + return AtomisticModel(to_export, metadata, self.capabilities) def _get_covariance(self, name: str) -> torch.Tensor: name = "covariance_" + name From 2531ac057c9a569e672983a418a659863d8a00fb Mon Sep 17 00:00:00 2001 From: Pol Febrer Calabozo Date: Tue, 20 Jan 2026 14:55:39 +0100 Subject: [PATCH 3/5] Fix MACE export on GPU --- src/metatrain/experimental/mace/model.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/metatrain/experimental/mace/model.py b/src/metatrain/experimental/mace/model.py index a9eec78a89..834645f0b4 100644 --- a/src/metatrain/experimental/mace/model.py +++ b/src/metatrain/experimental/mace/model.py @@ -610,7 +610,8 @@ def export(self, metadata: Optional[ModelMetadata] = None) -> AtomisticModel: metadata = merge_metadata(self.metadata, metadata) if 
self.hypers["mace_model"] is not None: - to_export = copy.deepcopy(self.eval()) + to_export = copy.copy(self.eval().to(device="cpu")) + to_export.mace_model = copy.deepcopy(to_export.mace_model) else: to_export = self.eval() From c988e6533c7a25c8579c1bd1e6a4deba1b60d3a8 Mon Sep 17 00:00:00 2001 From: frostedoyster Date: Wed, 21 Jan 2026 05:09:28 +0100 Subject: [PATCH 4/5] Hack for direct-force MACE LLPR --- src/metatrain/experimental/mace/model.py | 6 +- src/metatrain/llpr/model.py | 501 ++++++++++++++++------- 2 files changed, 356 insertions(+), 151 deletions(-) diff --git a/src/metatrain/experimental/mace/model.py b/src/metatrain/experimental/mace/model.py index 834645f0b4..4e13e277d8 100644 --- a/src/metatrain/experimental/mace/model.py +++ b/src/metatrain/experimental/mace/model.py @@ -635,16 +635,14 @@ def _add_output(self, target_name: str, target_info: TargetInfo) -> None: self.layouts[target_name] = target_info.layout + self.last_layer_feature_size = 128 + if target_name == self.hypers["mace_head_target"]: # Fake head that will not compute the target, but will help # us extract the last layer features from MACE internal head. self.heads[target_name] = MACEHeadWrapper( self.mace_model.readouts, self.per_layer_irreps ) - - self.last_layer_feature_size = self.heads[ - target_name - ].last_layer_features_irreps.dim else: head = NonLinearHead( irreps_in=self.features_irreps, diff --git a/src/metatrain/llpr/model.py b/src/metatrain/llpr/model.py index d7a2b6a857..1dd2fb036f 100644 --- a/src/metatrain/llpr/model.py +++ b/src/metatrain/llpr/model.py @@ -1,6 +1,5 @@ -import copy import logging -from typing import Any, Dict, List, Literal, Optional +from typing import Any, Dict, Iterator, List, Literal, Optional, Union import metatensor.torch as mts import numpy as np @@ -16,10 +15,20 @@ from torch.utils.data import DataLoader from metatrain.utils.abc import ModelInterface -from metatrain.utils.data import DatasetInfo, unpack_batch +from metatrain.utils.data import ( + CollateFn, + CombinedDataLoader, + Dataset, + DatasetInfo, + unpack_batch, +) from metatrain.utils.data.target_info import is_auxiliary_output from metatrain.utils.io import model_from_checkpoint from metatrain.utils.metadata import merge_metadata +from metatrain.utils.neighbor_lists import ( + get_requested_neighbor_lists, + get_system_with_neighbor_lists_transform, +) from . import checkpoints from .documentation import ModelHypers @@ -198,9 +207,14 @@ def set_wrapped_model(self, model: ModelInterface) -> None: self.llpr_ensemble_layers = torch.nn.ModuleDict() for name, value in self.ensemble_weight_sizes.items(): # create the linear layer for ensemble members + tensor_names = self.model.last_layer_parameter_names[name] + n_properties = torch.concatenate( + [self.model.state_dict()[tn] for tn in tensor_names], + axis=-1, + ).shape[0] # type: ignore self.llpr_ensemble_layers[name] = torch.nn.Linear( self.ll_feat_size, - value, + value * n_properties, bias=False, ) @@ -236,6 +250,85 @@ def restart(self, dataset_info: DatasetInfo) -> "LLPRUncertaintyModel": return self + def _get_dataloader( + self, + datasets: List[Union[Dataset, torch.utils.data.Subset]], + batch_size: int, + is_distributed: bool, + ) -> DataLoader: + """ + Create a DataLoader for the provided datasets. As the dataloader is only used to + accumulate the quantities needed for LLPR calibration, there is no need to + shuffle or drop the last non-full batch. 
Distributed sampling can be used or + not, based on the `is_distributed` argument, and training with double + precision is enforced. + + :param datasets: List of datasets to create the dataloader from. + :param batch_size: Batch size to use for the dataloader. + :param is_distributed: Whether to use distributed sampling or not. + :return: The created DataLoader. + """ + # Create the collate function + targets_keys = list(self.dataset_info.targets.keys()) + requested_neighbor_lists = get_requested_neighbor_lists(self) + collate_fn = CollateFn( + target_keys=targets_keys, + callables=[ + get_system_with_neighbor_lists_transform(requested_neighbor_lists) + ], + ) + + # Validate dtype from datasets + if len(datasets) == 0: + raise ValueError( + "Cannot create dataloader from empty datasets list. " + "Please provide non-empty datasets for LLPR calibration." + ) + if len(datasets[0]) == 0: + raise ValueError( + "Cannot create dataloader from empty dataset. " + "Please provide non-empty datasets for LLPR calibration." + ) + + # Build the dataloaders + samplers: List[torch.utils.data.Sampler | None] + if is_distributed: + world_size = torch.distributed.get_world_size() + rank = torch.distributed.get_rank() + samplers = [ + NoPadDistributedSampler( + dataset, + num_replicas=world_size, + rank=rank, + shuffle=False, + seed=0, + ) + for dataset in datasets + ] + else: + samplers = [None] * len(datasets) + + dataloaders = [] + for dataset, sampler in zip(datasets, samplers, strict=True): + if len(dataset) < batch_size: + raise ValueError( + f"A dataset has fewer samples " + f"({len(dataset)}) than the batch size " + f"({batch_size}). " + "Please reduce the batch size." + ) + dataloaders.append( + DataLoader( + dataset=dataset, + batch_size=batch_size, + sampler=sampler, + drop_last=False, + collate_fn=collate_fn, + ) + ) + + return CombinedDataLoader(dataloaders, shuffle=True) + def forward( self, systems: List[System], @@ -293,16 +386,23 @@ def forward( # special case for energy_ensemble ll_features_name = "mtt::aux::energy_last_layer_features" ll_features = return_dict[ll_features_name] - ll_feats_vals = self.get_ll_feats_vals(ll_features) - # compute PRs - # the code is the same for PR and LPR - one_over_pr_values = torch.einsum( - "ij, jk, ik -> i", - ll_feats_vals, - self._get_inv_covariance(uncertainty_name), - ll_feats_vals, - ).unsqueeze(1) + # compute PRs + # the code is the same for PR and LPR + if ll_features.block().values.ndim == 3: + one_over_pr_values = torch.einsum( + "icj, jk, ick -> i", + ll_features.block().values, + self._get_inv_covariance(uncertainty_name), + ll_features.block().values, + ).unsqueeze(1) + else: + one_over_pr_values = torch.einsum( + "ij, jk, ik -> i", + ll_features.block().values, + self._get_inv_covariance(uncertainty_name), + ll_features.block().values, + ).unsqueeze(1) original_name = self._get_original_name(uncertainty_name) @@ -377,16 +477,36 @@ def forward( # raw ens output shape is (samples, (num_ens * num_prop)) ensemble_values = module(ll_features.block().values) - # extract property labels and shape - cur_prop = return_dict[original_name].block().properties - num_prop = len(cur_prop.values) + # extract shape of components and properties + components_shape = list( + return_dict[original_name].block().values.shape[1:-1] + ) + num_prop = return_dict[original_name].block().values.shape[-1] # reshape values accordingly - ensemble_values = ensemble_values.reshape( - ensemble_values.shape[0], - -1, # num_ens - num_prop, - ) # shape: samples, num_ens, num_prop 
+ if num_prop == ensemble_values.shape[-1]: + # equivariant (or unconstrained with single component) + ensemble_values = ensemble_values.reshape( + [ensemble_values.shape[0]] + components_shape + [-1, num_prop] + ) # shape: samples, ..., num_ens, num_prop + else: + # unconstrained with multiple components + ensemble_values = ensemble_values.reshape( + [ensemble_values.shape[0], -1] + components_shape + [num_prop] + ) # shape: samples, num_ens, ..., num_prop + # move num_ens to position before num_prop (-2) + ensemble_values = ( + ensemble_values.reshape( + ensemble_values.shape[0], + -1, + _prod(components_shape), + num_prop, + ) + .swapaxes(1, 2) + .reshape( + [ensemble_values.shape[0]] + components_shape + [-1, num_prop] + ) + ) # shape: samples, ..., num_ens, num_prop # since we know the exact mean of the ensemble from the model's prediction, # it should be mathematically correct to use it to re-center the ensemble. @@ -397,19 +517,19 @@ def forward( # last layer, etc. ensemble_values = ( ensemble_values - - ensemble_values.mean(dim=1, keepdim=True) - + return_dict[original_name].block().values.unsqueeze(1) # ens_dim + - ensemble_values.mean(dim=-2, keepdim=True) + + return_dict[original_name].block().values.unsqueeze(-2) # ens_dim ) + num_ens = ensemble_values.shape[-2] + ensemble_values = ensemble_values.reshape( - ensemble_values.shape[0], - -1, - ) # shape: (samples, (num_ens * num_prop)) + [ensemble_values.shape[0]] + components_shape + [-1] + ) # shape: (samples, components, (num_ens * num_prop)) # prepare the properties Labels object for ensemble output, i.e. account # for the num_ens dimension old_prop_val = return_dict[original_name].block().properties.values - num_ens = ensemble_values.shape[1] num_samples = old_prop_val.shape[0] exp_prop_val = old_prop_val.repeat(num_ens, 1) ens_idxs = torch.arange( @@ -436,7 +556,7 @@ def forward( TensorBlock( values=ensemble_values, samples=ll_features.block().samples, - components=[], + components=return_dict[original_name].block().components, properties=ens_prop, ), ], @@ -453,47 +573,72 @@ def forward( return return_dict - def compute_covariance(self, train_loader: DataLoader) -> None: + def compute_covariance( + self, + datasets: List[Union[Dataset, torch.utils.data.Subset]], + batch_size: int, + is_distributed: bool, + ) -> None: """A function to compute the covariance matrix for a training set. The covariance is stored as a buffer in the model. - :param train_loader: A PyTorch DataLoader with the training data. - The individual samples need to be compatible with the ``Dataset`` - class in ``metatrain``. + :param datasets: List of datasets to use for covariance calculation. + :param batch_size: Batch size to use for the dataloader. + :param is_distributed: Whether to use distributed sampling or not. 
""" + # Create dataloader for the training datasets + train_loader = self._get_dataloader( + datasets, batch_size, is_distributed=is_distributed + ) + device = next(iter(self.buffers())).device dtype = next(iter(self.buffers())).dtype - for batch in train_loader: - systems, targets, extra_data = unpack_batch(batch) - n_atoms = torch.tensor( - [len(system.positions) for system in systems], device=device - ) - systems = [system.to(device=device, dtype=dtype) for system in systems] - outputs_for_targets = { - name: ModelOutput(per_atom="atom" in target.block(0).samples.names) - for name, target in targets.items() - } - outputs_for_features = { - f"mtt::aux::{n.replace('mtt::', '')}_last_layer_features": o - for n, o in outputs_for_targets.items() - } - output = self.forward( - systems, {**outputs_for_targets, **outputs_for_features} - ) - for name in targets.keys(): - ll_feat_tmap = output[ - f"mtt::aux::{name.replace('mtt::', '')}_last_layer_features" - ] - # TODO: interface ll_feat calculation with the loss function, - # paying attention to normalization w.r.t. n_atoms - if not outputs_for_targets[name].per_atom: - ll_feats = self.get_ll_feats_vals( - ll_feat_tmap - ).detach() / n_atoms.unsqueeze(1) + with torch.no_grad(): + for batch in train_loader: + systems, targets, _ = unpack_batch(batch) + n_atoms = torch.tensor( + [len(system.positions) for system in systems], device=device + ) + systems = [system.to(device=device, dtype=dtype) for system in systems] + outputs_for_targets = { + name: ModelOutput(per_atom="atom" in target.block(0).samples.names) + for name, target in targets.items() + } + outputs_for_features = { + f"mtt::aux::{n.replace('mtt::', '')}_last_layer_features": o + for n, o in outputs_for_targets.items() + } + output = self.forward( + systems, {**outputs_for_targets, **outputs_for_features} + ) + for name in targets.keys(): + ll_feat_tmap = output[ + f"mtt::aux::{name.replace('mtt::', '')}_last_layer_features" + ] + # TODO: interface ll_feat calculation with the loss function, + # paying attention to normalization w.r.t. n_atoms + if not outputs_for_targets[name].per_atom: + ll_feats = ( + ll_feat_tmap.block().values.detach() / n_atoms.unsqueeze(1) + ) + else: + # For per-atom targets, use the features directly + ll_feats = ll_feat_tmap.block().values.detach() + if ll_feats.ndim > 2: + # flatten component dimensions into samples + ll_feats = ll_feats.reshape(-1, ll_feats.shape[-1]) + uncertainty_name = _get_uncertainty_name(name) + covariance = self._get_covariance(uncertainty_name) + covariance += ll_feats.T @ ll_feats + + if is_distributed: + torch.distributed.barrier() + # All-reduce the covariance matrices across all processes + for name in self.outputs_list: uncertainty_name = _get_uncertainty_name(name) covariance = self._get_covariance(uncertainty_name) - covariance += ll_feats.T @ ll_feats + torch.distributed.all_reduce(covariance) def compute_inverse_covariance(self, regularizer: Optional[float] = None) -> None: """A function to compute the inverse covariance matrix. @@ -538,7 +683,13 @@ def is_psd(x: torch.Tensor) -> torch.Tensor: inv_covariance[:] = (inverse + inverse.T) / 2.0 break - def calibrate(self, valid_loader: DataLoader) -> None: + def calibrate( + self, + datasets: List[Union[Dataset, torch.utils.data.Subset]], + batch_size: int, + is_distributed: bool, + use_absolute_residuals: bool, + ) -> None: """ Calibrate the LLPR model. 
@@ -548,87 +699,104 @@ def calibrate(self, valid_loader: DataLoader) -> None: constant as the mean of the squared residuals divided by the mean of the non-calibrated uncertainties. - :param valid_loader: A data loader with the validation data. - This data loader should be generated from a dataset from the - ``Dataset`` class in ``metatrain.utils.data``. + :param datasets: List of datasets to use for calibration. + :param batch_size: Batch size to use for the dataloader. + :param is_distributed: Whether to use distributed sampling or not. + :param use_absolute_residuals: Whether to use absolute residuals as opposed + to squared residuals for the calibration. In both cases, a Gaussian + error distribution is assumed in order to derive the calibration constants, + but using absolute residuals can help reduce the effect of large outliers. """ - # calibrate the LLPR + # Create dataloader for the validation datasets + valid_loader = self._get_dataloader( + datasets, batch_size, is_distributed=is_distributed + ) + + # infer device and dtype device = next(iter(self.buffers())).device dtype = next(iter(self.buffers())).dtype - all_predictions = {} # type: ignore - all_targets = {} # type: ignore - all_uncertainties = {} # type: ignore - - for batch in valid_loader: - systems, targets, extra_data = unpack_batch(batch) - systems = [system.to(device=device, dtype=dtype) for system in systems] - targets = { - name: target.to(device=device, dtype=dtype) - for name, target in targets.items() - } - requested_outputs = {} - for name in targets: - per_atom = "atom" in targets[name].block(0).samples.names - requested_outputs[name] = ModelOutput(per_atom=per_atom) - uncertainty_name = _get_uncertainty_name(name) - requested_outputs[uncertainty_name] = ModelOutput(per_atom=per_atom) - outputs = self.forward(systems, requested_outputs) - for name, target in targets.items(): - uncertainty_name = _get_uncertainty_name(name) - if name not in all_predictions: - all_predictions[name] = [] - all_targets[name] = [] - all_uncertainties[uncertainty_name] = [] - all_predictions[name].append(outputs[name].block().values.detach()) - all_targets[name].append(target.block().values) - all_uncertainties[uncertainty_name].append( - outputs[uncertainty_name].block().values.detach() - ) - for name in all_predictions: - all_predictions[name] = torch.cat(all_predictions[name], dim=0) - all_targets[name] = torch.cat(all_targets[name], dim=0) - uncertainty_name = _get_uncertainty_name(name) - all_uncertainties[uncertainty_name] = torch.cat( - all_uncertainties[uncertainty_name], dim=0 - ) - - for name in all_predictions: - # compute the uncertainty multiplier - residuals = all_predictions[name] - all_targets[name] - squared_residuals = residuals**2 - if squared_residuals.ndim > 2: - # squared residuals need to be summed over component dimensions, - # i.e., all but the first and last dimensions - squared_residuals = torch.sum( - squared_residuals, - dim=tuple(range(1, squared_residuals.ndim - 1)), - ) - uncertainty_name = _get_uncertainty_name(name) - uncertainties = all_uncertainties[uncertainty_name] - ratios = squared_residuals / uncertainties**2 # can be multi-dimensional - multiplier = self._get_multiplier(uncertainty_name) - multiplier[:] = torch.sqrt(torch.mean(ratios, dim=0)) # only along samples + sums = {} # type: ignore + counts = {} # type: ignore + + with torch.no_grad(): + for batch in valid_loader: + systems, targets, _ = unpack_batch(batch) + systems = [system.to(device=device, dtype=dtype) for system in systems] + 
targets = { + name: target.to(device=device, dtype=dtype) + for name, target in targets.items() + } + requested_outputs = {} + for name in targets: + per_atom = "atom" in targets[name].block(0).samples.names + requested_outputs[name] = ModelOutput(per_atom=per_atom) + uncertainty_name = _get_uncertainty_name(name) + requested_outputs[uncertainty_name] = ModelOutput(per_atom=per_atom) + + outputs = self.forward(systems, requested_outputs) + + for name, target in targets.items(): + uncertainty_name = _get_uncertainty_name(name) + + pred = outputs[name].block().values.detach() + targ = target.block().values + unc = outputs[uncertainty_name].block().values.detach() + + # compute the uncertainty multiplier + residuals = pred - targ + squared_residuals = residuals**2 + if squared_residuals.ndim > 2: + # squared residuals need to be summed over component dimensions, + # i.e., all but the first and last dimensions + squared_residuals = torch.sum( + squared_residuals, + dim=tuple(range(1, squared_residuals.ndim - 1)), + ) - def get_ll_feats_vals(self, ll_features_tmap: TensorMap) -> torch.Tensor: - """Get the last-layer features values from a TensorMap. + if use_absolute_residuals: + ratios = torch.sqrt(squared_residuals) / unc + else: + ratios = squared_residuals / (unc**2) - It ensures that the last-layer features are scalars and - makes sure that the values have no component dimension. + ratios_sum64 = torch.sum(ratios.to(torch.float64), dim=0) + count = torch.tensor( + ratios.shape[0], dtype=torch.long, device=device + ) - :param ll_features_tmap: A TensorMap containing last-layer features. - :return: A tensor with the last-layer features values. - """ - block = ll_features_tmap.block() - block_shape = block.values.shape - if len(block_shape) > 3 or (len(block_shape) == 3 and block_shape[1] != 1): - raise ValueError( - "Can't use last layer features other than scalars. Received " - f"last layer features are:\n{block}" - ) - # Make sure there is no component dimension - ll_feats_vals = block.values.reshape(block_shape[0], -1) - return ll_feats_vals + if uncertainty_name not in sums: + sums[uncertainty_name] = ratios_sum64 + counts[uncertainty_name] = count + else: + sums[uncertainty_name] = sums[uncertainty_name] + ratios_sum64 + counts[uncertainty_name] = counts[uncertainty_name] + count + + if is_distributed: + # All-reduce the accumulated statistics across all processes + for uncertainty_name in sums: + torch.distributed.all_reduce(sums[uncertainty_name]) + torch.distributed.all_reduce(counts[uncertainty_name]) + + for uncertainty_name in sums: + if use_absolute_residuals: + # "MAE"-style calibration + global_mean64 = sums[uncertainty_name] / counts[uncertainty_name].to( + torch.float64 + ) + else: + # "RMSE"-style calibration + global_mean64 = torch.sqrt( + sums[uncertainty_name] / counts[uncertainty_name].to(torch.float64) + ) + multiplier = self._get_multiplier(uncertainty_name) + if use_absolute_residuals: + # apply absolute correction factor (inverse of integral of abs(x) + # over Gaussian, i.e., sqrt(pi/2)) + multiplier[:] = (global_mean64 * np.sqrt(np.pi / 2.0)).to( + multiplier.dtype + ) + else: + multiplier[:] = global_mean64.to(multiplier.dtype) def generate_ensemble(self) -> None: """Generate an ensemble of weights for the model. 
@@ -664,14 +832,19 @@ def generate_ensemble(self) -> None: .cpu() .numpy() ) - rng = np.random.default_rng() + rng = np.random.default_rng(42) ensemble_weights = [] for ii in range(weights.shape[0]): + # TODO: this isn't good enough for multi-target equivariant + if cur_multiplier.shape[0] > 1: # unconstrained models + multiplier = cur_multiplier[ii].item() + else: # equivariant models + multiplier = cur_multiplier.item() cur_ensemble_weights = rng.multivariate_normal( weights[ii].clone().detach().cpu().numpy(), - cur_inv_covariance * cur_multiplier[ii].item() ** 2, + cur_inv_covariance * multiplier**2, size=self.ensemble_weight_sizes[name], method="svd", ).T @@ -792,20 +965,12 @@ def export(self, metadata: Optional[ModelMetadata] = None) -> AtomisticModel: # no weights to move pass - # Make sure the wrapped model is ready to be exported, for that - # we shallow-copy the LLPR model and replace the wrapped model - # in this shallow copy with the exported version of the - # wrapped model - to_export = copy.copy(self.eval()) - exported_wrapped = self.model.export() - to_export.model = exported_wrapped.module - metadata = merge_metadata( merge_metadata(self.__default_metadata__, metadata), - exported_wrapped.metadata(), + self.model.export().metadata(), ) - return AtomisticModel(to_export, metadata, self.capabilities) + return AtomisticModel(self.eval(), metadata, self.capabilities) def _get_covariance(self, name: str) -> torch.Tensor: name = "covariance_" + name @@ -878,3 +1043,45 @@ def _get_uncertainty_name(name: str) -> str: else: uncertainty_name = f"mtt::aux::{name.replace('mtt::', '')}_uncertainty" return uncertainty_name + + +class NoPadDistributedSampler(torch.utils.data.Sampler[int]): + def __init__( + self, + dataset: torch.utils.data.Dataset, + num_replicas: int, + rank: int, + shuffle: bool = False, + seed: int = 0, + ): + self.dataset = dataset + self.num_replicas = num_replicas + self.rank = rank + self.shuffle = shuffle + self.seed = seed + self.epoch = 0 + + def set_epoch(self, epoch: int) -> None: + self.epoch = epoch + + def __iter__(self) -> Iterator[int]: + n = len(self.dataset) + indices = torch.arange(n, dtype=torch.long) + if self.shuffle: + g = torch.Generator() + g.manual_seed(self.seed + self.epoch) + indices = indices[torch.randperm(n, generator=g)] + # Key property: no padding, no dropping + return iter(indices[self.rank :: self.num_replicas].tolist()) + + def __len__(self) -> int: + n = len(self.dataset) + return (n - self.rank + self.num_replicas - 1) // self.num_replicas + + +def _prod(list_of_int: List[int]) -> int: + # for torchscript compatibility (math.prod is not supported) + result = 1 + for x in list_of_int: + result = result * x + return result From ecc86f6864e70c970f92acedf303e75dee1d899a Mon Sep 17 00:00:00 2001 From: frostedoyster Date: Wed, 21 Jan 2026 05:15:26 +0100 Subject: [PATCH 5/5] Fix linter --- src/metatrain/llpr/model.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/metatrain/llpr/model.py b/src/metatrain/llpr/model.py index 1dd2fb036f..3f057ec207 100644 --- a/src/metatrain/llpr/model.py +++ b/src/metatrain/llpr/model.py @@ -387,8 +387,8 @@ def forward( ll_features_name = "mtt::aux::energy_last_layer_features" ll_features = return_dict[ll_features_name] - # compute PRs - # the code is the same for PR and LPR + # compute PRs + # the code is the same for PR and LPR if ll_features.block().values.ndim == 3: one_over_pr_values = torch.einsum( "icj, jk, ick -> i",
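
For reference, a self-contained sketch of the LLPR bookkeeping that patches 1/5 and 4/5 route the MACE last-layer features through: a covariance accumulated from last-layer features (compute_covariance), its symmetrized regularized inverse (compute_inverse_covariance), and the per-sample 1/PR computed in forward() with the same einsum. All tensors, sizes, and variable names below are synthetic stand-ins, not metatrain objects.

import torch

torch.manual_seed(0)

n_train, n_test, d = 256, 8, 32          # d plays the role of last_layer_feature_size
train_feats = torch.randn(n_train, d)    # stand-in for accumulated last-layer features
test_feats = torch.randn(n_test, d)

# compute_covariance: accumulate F^T F batch by batch (a single batch here)
covariance = train_feats.T @ train_feats

# compute_inverse_covariance: regularize, invert, and symmetrize
regularizer = 1e-6
inverse = torch.linalg.inv(covariance + regularizer * torch.eye(d))
inv_covariance = (inverse + inverse.T) / 2.0

# forward(), scalar (2-D) feature path: 1/PR per sample, then sqrt for the
# (uncalibrated) uncertainty
one_over_pr = torch.einsum("ij, jk, ik -> i", test_feats, inv_covariance, test_feats)
uncertainty = torch.sqrt(one_over_pr).unsqueeze(1)
print(uncertainty.shape)  # torch.Size([8, 1])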
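
The new 3-D branch in forward() handles last-layer features that carry a component axis. A quick check, with synthetic shapes, that the "icj, jk, ick -> i" contraction is the component-summed version of the scalar "ij, jk, ik -> i" path:

import torch

torch.manual_seed(0)

feats_3d = torch.randn(5, 3, 16)   # (samples, components, features)
inv_cov = torch.eye(16)

# component-aware contraction used when ll_features.block().values.ndim == 3
one_over_pr = torch.einsum("icj, jk, ick -> i", feats_3d, inv_cov, feats_3d)

# same thing via the 2-D path: flatten components into samples, then sum back
flat = feats_3d.reshape(-1, 16)
per_component = torch.einsum("ij, jk, ik -> i", flat, inv_cov, flat)
assert torch.allclose(one_over_pr, per_component.reshape(5, 3).sum(dim=1))
print(one_over_pr.shape)  # torch.Size([5])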
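
The use_absolute_residuals option added to calibrate() in patch 4/5 swaps the squared-residual estimate of the multiplier for a mean-absolute-residual estimate rescaled by sqrt(pi/2), the factor that turns the mean absolute deviation of a Gaussian back into its standard deviation. A toy check with synthetic residuals, assuming unit uncalibrated uncertainties:

import numpy as np

rng = np.random.default_rng(0)
true_sigma = 2.0
residuals = rng.normal(0.0, true_sigma, size=100_000)   # prediction - target
raw_uncertainty = np.ones_like(residuals)               # uncalibrated uncertainties

# squared-residual ("RMSE"-style) calibration
rmse_multiplier = np.sqrt(np.mean(residuals**2 / raw_uncertainty**2))

# absolute-residual calibration: mean(|r| / unc) * sqrt(pi / 2), which recovers
# the Gaussian standard deviation but is less sensitive to large outliers
abs_multiplier = np.mean(np.abs(residuals) / raw_uncertainty) * np.sqrt(np.pi / 2.0)

print(rmse_multiplier, abs_multiplier)  # both close to 2.0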
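
generate_ensemble() draws ensemble members from a Gaussian centred on the last-layer weights with covariance inv_covariance * multiplier**2. A simplified single-row sketch using the same numpy call; the real code loops over weight rows and distinguishes per-property multipliers for unconstrained targets, and weights, inv_covariance, and multiplier below are made up:

import numpy as np
import torch

d, n_members = 32, 16
weights = torch.randn(d)              # one row of a last-layer weight matrix
inv_covariance = 0.01 * torch.eye(d)  # stand-in for the calibrated inverse covariance
multiplier = 1.5                      # stand-in for the calibration multiplier

rng = np.random.default_rng(42)
ensemble_weights = rng.multivariate_normal(
    weights.numpy(),
    inv_covariance.numpy() * multiplier**2,
    size=n_members,
    method="svd",
).T
print(ensemble_weights.shape)  # (32, 16)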
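
Finally, a sanity check of the round-robin sharding used by NoPadDistributedSampler (rank::num_replicas, no padding, no dropping), which is what keeps the distributed covariance and calibration accumulations from double-counting samples. n and num_replicas here are arbitrary:

n, num_replicas = 10, 3
shards = [list(range(n))[rank::num_replicas] for rank in range(num_replicas)]
print(shards)  # [[0, 3, 6, 9], [1, 4, 7], [2, 5, 8]]

# every sample appears exactly once across ranks...
assert sorted(i for shard in shards for i in shard) == list(range(n))
# ...and the shard sizes match the sampler's __len__ formula
assert [len(s) for s in shards] == [
    (n - rank + num_replicas - 1) // num_replicas for rank in range(num_replicas)
]

Because shard sizes can differ by one, per-rank means would be biased, which is presumably why calibrate() all-reduces sums and counts separately before dividing.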