42 changes: 30 additions & 12 deletions src/metatrain/experimental/mace/model.py
@@ -1,3 +1,4 @@
import copy
import logging
import warnings
from pathlib import Path
@@ -130,7 +131,7 @@ def __init__(self, hypers: ModelHypers, dataset_info: DatasetInfo) -> None:
if hasattr(self.mace_model, "atomic_energies_fn"):
self._loaded_atomic_baseline = (
self.mace_model.atomic_energies_fn.atomic_energies.clone()
-)
+).ravel()
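As a side note, `.ravel()` just flattens the cached baseline to one dimension; a tiny illustration (the `(1, n_species)` starting shape is an assumption for the example, not taken from MACE):

```python
import torch

# assumed shape, for illustration only: one row of per-species baseline energies
atomic_energies = torch.tensor([[1.0, 2.0, 3.0]])
baseline = atomic_energies.clone().ravel()
print(atomic_energies.shape)  # torch.Size([1, 3])
print(baseline.shape)         # torch.Size([3])
```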

self.mace_model.atomic_energies_fn.atomic_energies[:] = 0.0

@@ -574,6 +575,22 @@ def load_checkpoint(

return model

def _get_capabilities(self) -> ModelCapabilities:
dtype = next(self.parameters()).dtype

interaction_range = self.hypers["num_interactions"] * self.cutoff

capabilities = ModelCapabilities(
outputs=self.outputs,
atomic_types=self.atomic_types,
interaction_range=interaction_range,
length_unit=self.dataset_info.length_unit,
supported_devices=self.__supported_devices__,
dtype=dtype_to_str(dtype),
)

return capabilities
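As a quick illustration of the formula used here (made-up values, not the defaults): each message-passing interaction extends the receptive field by one cutoff, so the effective interaction range is the product of the two.

```python
# made-up hyperparameters, only to illustrate interaction_range = num_interactions * cutoff
num_interactions = 2   # two message-passing layers
cutoff = 5.0           # local cutoff, in the dataset's length unit
interaction_range = num_interactions * cutoff
print(interaction_range)  # 10.0 -> an atom can influence atoms up to two cutoffs away
```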

def export(self, metadata: Optional[ModelMetadata] = None) -> AtomisticModel:
dtype = next(self.parameters()).dtype
if dtype not in self.__supported_dtypes__:
@@ -588,20 +605,19 @@ def export(self, metadata: Optional[ModelMetadata] = None) -> AtomisticModel:
# be registered correctly with Pytorch. This function moves them:
self.additive_models[0].weights_to(torch.device("cpu"), torch.float64)

-interaction_range = self.hypers["num_interactions"] * self.cutoff
-
-capabilities = ModelCapabilities(
-    outputs=self.outputs,
-    atomic_types=self.atomic_types,
-    interaction_range=interaction_range,
-    length_unit=self.dataset_info.length_unit,
-    supported_devices=self.__supported_devices__,
-    dtype=dtype_to_str(dtype),
-)
+capabilities = self._get_capabilities()

metadata = merge_metadata(self.metadata, metadata)

-return AtomisticModel(jit.compile(self.eval()), metadata, capabilities)
+if self.hypers["mace_model"] is not None:
+    to_export = copy.copy(self.eval().to(device="cpu"))
+    to_export.mace_model = copy.deepcopy(to_export.mace_model)
+else:
+    to_export = self.eval()
+
+model = AtomisticModel(jit.compile(to_export), metadata, capabilities)
+
+return model
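The copy above appears intended to keep the in-memory model usable after export, since compilation and the CPU move would otherwise touch the loaded `mace_model`. A minimal sketch of the shallow-copy-plus-deep-copy pattern, using generic classes rather than the metatrain API:

```python
import copy

class Wrapper:
    def __init__(self, inner):
        self.inner = inner

original = Wrapper(inner=[1.0, 2.0])
exported = copy.copy(original)                  # shallow copy: still shares `inner`
exported.inner = copy.deepcopy(original.inner)  # give the export its own `inner`
exported.inner.append(3.0)

print(original.inner)  # [1.0, 2.0]      -- original left untouched
print(exported.inner)  # [1.0, 2.0, 3.0] -- safe to modify independently
```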

def _add_output(self, target_name: str, target_info: TargetInfo) -> None:
"""
@@ -619,6 +635,8 @@ def _add_output(self, target_name: str, target_info: TargetInfo) -> None:

self.layouts[target_name] = target_info.layout

self.last_layer_feature_size = 128

Contributor Author:
I think this doesn't work with:

  1. The internal MACE head.
  2. Heads that output spherical harmonics with more than one irrep.

I think Sanggyu is working on a branch to support multiple heads.

Contributor Author:
In this branch I just made the internal MACE head work, because that is what LLPR supported (a single head with scalar features).
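To make that limitation concrete, a shape-only sketch (sizes are made up; 128 just mirrors the hard-coded `last_layer_feature_size`): a scalar head yields 2-D per-atom features, while a head emitting higher irreps carries an extra component axis.

```python
import torch

n_atoms, n_features = 4, 128
scalar_feats = torch.randn(n_atoms, n_features)  # (4, 128): single scalar head, what LLPR expects
l1_feats = torch.randn(n_atoms, 3, n_features)   # (4, 3, 128): an L=1 irrep adds 2L+1 = 3 components
print(scalar_feats.ndim, l1_feats.ndim)          # 2 3
```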

if target_name == self.hypers["mace_head_target"]:
# Fake head that will not compute the target, but will help
# us extract the last layer features from MACE internal head.
41 changes: 26 additions & 15 deletions src/metatrain/llpr/model.py
@@ -389,12 +389,20 @@ def forward(

# compute PRs
# the code is the same for PR and LPR
-one_over_pr_values = torch.einsum(
-    "ij, jk, ik -> i",
-    ll_features.block().values,
-    self._get_inv_covariance(uncertainty_name),
-    ll_features.block().values,
-).unsqueeze(1)
+if ll_features.block().values.ndim == 3:
+    one_over_pr_values = torch.einsum(
+        "icj, jk, ick -> i",
+        ll_features.block().values,
+        self._get_inv_covariance(uncertainty_name),
+        ll_features.block().values,
+    ).unsqueeze(1)
+else:
+    one_over_pr_values = torch.einsum(
+        "ij, jk, ik -> i",
+        ll_features.block().values,
+        self._get_inv_covariance(uncertainty_name),
+        ll_features.block().values,
+    ).unsqueeze(1)
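As a sketch of why the new 3-D branch is consistent with the old 2-D one (random data, made-up sizes): contracting `"icj, jk, ick -> i"` is the same as applying the 2-D contraction to each component slice and summing.

```python
import torch

torch.manual_seed(0)
n_samples, n_components, n_features = 5, 3, 8   # made-up sizes
feats = torch.randn(n_samples, n_components, n_features)
inv_cov = torch.randn(n_features, n_features)

# 3-D path: for every sample i, sum over components c of  x_c^T  inv_cov  x_c
three_d = torch.einsum("icj, jk, ick -> i", feats, inv_cov, feats)

# reference: apply the 2-D contraction component by component and add up
reference = sum(
    torch.einsum("ij, jk, ik -> i", feats[:, c, :], inv_cov, feats[:, c, :])
    for c in range(n_components)
)
assert torch.allclose(three_d, reference)
```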

original_name = self._get_original_name(uncertainty_name)

@@ -417,7 +425,7 @@ def forward(
TensorBlock(
values=torch.sqrt(one_over_pr_values.expand((-1, num_prop))),
samples=ll_features.block().samples,
-components=ll_features.block().components,
+components=[],
properties=cur_prop,
)
],
@@ -438,7 +446,7 @@ def forward(
one_over_pr_values.shape[0], num_prop
),
samples=ll_features.block().samples,
-components=ll_features.block().components,
+components=[],
properties=cur_prop,
)
],
@@ -617,6 +625,9 @@ def compute_covariance(
else:
# For per-atom targets, use the features directly
ll_feats = ll_feat_tmap.block().values.detach()
if ll_feats.ndim > 2:
# flatten component dimensions into samples
ll_feats = ll_feats.reshape(-1, ll_feats.shape[-1])
uncertainty_name = _get_uncertainty_name(name)
covariance = self._get_covariance(uncertainty_name)
covariance += ll_feats.T @ ll_feats
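A small sketch (random data, made-up sizes) of what folding the component axis into the sample axis does here: the accumulated Gram matrix is the same as summing the per-component Gram matrices.

```python
import torch

torch.manual_seed(0)
ll_feats = torch.randn(6, 3, 16)                  # made-up (samples, components, features)

flat = ll_feats.reshape(-1, ll_feats.shape[-1])   # (18, 16): components folded into samples
covariance = flat.T @ flat                        # (16, 16)

reference = sum(
    ll_feats[:, c, :].T @ ll_feats[:, c, :] for c in range(ll_feats.shape[1])
)
assert torch.allclose(covariance, reference)
```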
@@ -734,19 +745,19 @@ def calibrate(

# compute the uncertainty multiplier
residuals = pred - targ
-abs_residuals = torch.abs(residuals)
-if abs_residuals.ndim > 2:
+squared_residuals = residuals**2
+if squared_residuals.ndim > 2:
    # squared residuals need to be summed over component dimensions,
    # i.e., all but the first and last dimensions
-    abs_residuals = torch.sum(
-        abs_residuals,
-        dim=tuple(range(1, abs_residuals.ndim - 1)),
+    squared_residuals = torch.sum(
+        squared_residuals,
+        dim=tuple(range(1, squared_residuals.ndim - 1)),
)

if use_absolute_residuals:
-    ratios = abs_residuals / unc # can be multi-dimensional
+    ratios = torch.sqrt(squared_residuals) / unc
else:
-    ratios = (residuals**2) / (unc**2)
+    ratios = squared_residuals / (unc**2)

ratios_sum64 = torch.sum(ratios.to(torch.float64), dim=0)
count = torch.tensor(
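A minimal sketch of the calibration arithmetic introduced above (random data, made-up shapes): squared residuals are summed over every component dimension, and the two ratio variants follow from that sum. The final reduction into a calibration multiplier is not shown, since that part of the hunk is truncated.

```python
import torch

torch.manual_seed(0)
residuals = torch.randn(10, 3, 1)   # made-up (structures, components, properties)
unc = torch.rand(10, 1) + 0.1       # predicted uncertainty per structure

# sum squared residuals over all component dimensions (everything but first and last)
squared = torch.sum(residuals**2, dim=tuple(range(1, residuals.ndim - 1)))  # (10, 1)

ratios_absolute = torch.sqrt(squared) / unc   # used when use_absolute_residuals is set
ratios_squared = squared / unc**2             # default squared-residual calibration
print(ratios_absolute.mean(dim=0), ratios_squared.mean(dim=0))
```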
5 changes: 5 additions & 0 deletions src/metatrain/utils/testing/output.py
@@ -518,6 +518,11 @@ def test_output_last_layer_features(

model = self.model_cls(model_hypers, dataset_info)

assert hasattr(model, "last_layer_feature_size"), (
f"{self.architecture} does not have the attribute "
"`last_layer_feature_size`."
)

system = System(
types=torch.tensor([6, 1, 8, 7]),
positions=torch.tensor(