diff --git a/src/metatrain/experimental/mace/model.py b/src/metatrain/experimental/mace/model.py index f30a5968cc..4e13e277d8 100644 --- a/src/metatrain/experimental/mace/model.py +++ b/src/metatrain/experimental/mace/model.py @@ -1,3 +1,4 @@ +import copy import logging import warnings from pathlib import Path @@ -130,7 +131,7 @@ def __init__(self, hypers: ModelHypers, dataset_info: DatasetInfo) -> None: if hasattr(self.mace_model, "atomic_energies_fn"): self._loaded_atomic_baseline = ( self.mace_model.atomic_energies_fn.atomic_energies.clone() - ) + ).ravel() self.mace_model.atomic_energies_fn.atomic_energies[:] = 0.0 @@ -574,6 +575,22 @@ def load_checkpoint( return model + def _get_capabilities(self) -> ModelCapabilities: + dtype = next(self.parameters()).dtype + + interaction_range = self.hypers["num_interactions"] * self.cutoff + + capabilities = ModelCapabilities( + outputs=self.outputs, + atomic_types=self.atomic_types, + interaction_range=interaction_range, + length_unit=self.dataset_info.length_unit, + supported_devices=self.__supported_devices__, + dtype=dtype_to_str(dtype), + ) + + return capabilities + def export(self, metadata: Optional[ModelMetadata] = None) -> AtomisticModel: dtype = next(self.parameters()).dtype if dtype not in self.__supported_dtypes__: @@ -588,20 +605,19 @@ def export(self, metadata: Optional[ModelMetadata] = None) -> AtomisticModel: # be registered correctly with Pytorch. This function moves them: self.additive_models[0].weights_to(torch.device("cpu"), torch.float64) - interaction_range = self.hypers["num_interactions"] * self.cutoff - - capabilities = ModelCapabilities( - outputs=self.outputs, - atomic_types=self.atomic_types, - interaction_range=interaction_range, - length_unit=self.dataset_info.length_unit, - supported_devices=self.__supported_devices__, - dtype=dtype_to_str(dtype), - ) + capabilities = self._get_capabilities() metadata = merge_metadata(self.metadata, metadata) - return AtomisticModel(jit.compile(self.eval()), metadata, capabilities) + if self.hypers["mace_model"] is not None: + to_export = copy.copy(self.eval().to(device="cpu")) + to_export.mace_model = copy.deepcopy(to_export.mace_model) + else: + to_export = self.eval() + + model = AtomisticModel(jit.compile(to_export), metadata, capabilities) + + return model def _add_output(self, target_name: str, target_info: TargetInfo) -> None: """ @@ -619,6 +635,8 @@ def _add_output(self, target_name: str, target_info: TargetInfo) -> None: self.layouts[target_name] = target_info.layout + self.last_layer_feature_size = 128 + if target_name == self.hypers["mace_head_target"]: # Fake head that will not compute the target, but will help # us extract the last layer features from MACE internal head. diff --git a/src/metatrain/llpr/model.py b/src/metatrain/llpr/model.py index 46419c67fd..3f057ec207 100644 --- a/src/metatrain/llpr/model.py +++ b/src/metatrain/llpr/model.py @@ -389,12 +389,20 @@ def forward( # compute PRs # the code is the same for PR and LPR - one_over_pr_values = torch.einsum( - "ij, jk, ik -> i", - ll_features.block().values, - self._get_inv_covariance(uncertainty_name), - ll_features.block().values, - ).unsqueeze(1) + if ll_features.block().values.ndim == 3: + one_over_pr_values = torch.einsum( + "icj, jk, ick -> i", + ll_features.block().values, + self._get_inv_covariance(uncertainty_name), + ll_features.block().values, + ).unsqueeze(1) + else: + one_over_pr_values = torch.einsum( + "ij, jk, ik -> i", + ll_features.block().values, + self._get_inv_covariance(uncertainty_name), + ll_features.block().values, + ).unsqueeze(1) original_name = self._get_original_name(uncertainty_name) @@ -417,7 +425,7 @@ def forward( TensorBlock( values=torch.sqrt(one_over_pr_values.expand((-1, num_prop))), samples=ll_features.block().samples, - components=ll_features.block().components, + components=[], properties=cur_prop, ) ], @@ -438,7 +446,7 @@ def forward( one_over_pr_values.shape[0], num_prop ), samples=ll_features.block().samples, - components=ll_features.block().components, + components=[], properties=cur_prop, ) ], @@ -617,6 +625,9 @@ def compute_covariance( else: # For per-atom targets, use the features directly ll_feats = ll_feat_tmap.block().values.detach() + if ll_feats.ndim > 2: + # flatten component dimensions into samples + ll_feats = ll_feats.reshape(-1, ll_feats.shape[-1]) uncertainty_name = _get_uncertainty_name(name) covariance = self._get_covariance(uncertainty_name) covariance += ll_feats.T @ ll_feats @@ -734,19 +745,19 @@ def calibrate( # compute the uncertainty multiplier residuals = pred - targ - abs_residuals = torch.abs(residuals) - if abs_residuals.ndim > 2: + squared_residuals = residuals**2 + if squared_residuals.ndim > 2: # squared residuals need to be summed over component dimensions, # i.e., all but the first and last dimensions - abs_residuals = torch.sum( - abs_residuals, - dim=tuple(range(1, abs_residuals.ndim - 1)), + squared_residuals = torch.sum( + squared_residuals, + dim=tuple(range(1, squared_residuals.ndim - 1)), ) if use_absolute_residuals: - ratios = abs_residuals / unc # can be multi-dimensional + ratios = torch.sqrt(squared_residuals) / unc else: - ratios = (residuals**2) / (unc**2) + ratios = squared_residuals / (unc**2) ratios_sum64 = torch.sum(ratios.to(torch.float64), dim=0) count = torch.tensor( diff --git a/src/metatrain/utils/testing/output.py b/src/metatrain/utils/testing/output.py index 929d2e3d15..4e2b6e37b9 100644 --- a/src/metatrain/utils/testing/output.py +++ b/src/metatrain/utils/testing/output.py @@ -518,6 +518,11 @@ def test_output_last_layer_features( model = self.model_cls(model_hypers, dataset_info) + assert hasattr(model, "last_layer_feature_size"), ( + f"{self.architecture} does not have the attribute " + "`last_layer_feature_size`." + ) + system = System( types=torch.tensor([6, 1, 8, 7]), positions=torch.tensor(