Merged
46 changes: 46 additions & 0 deletions src/metatrain/llpr/checkpoints.py
@@ -1,3 +1,6 @@
import torch


def model_update_v1_v2(checkpoint: dict) -> None:
"""
Update a v1 checkpoint to v2.
@@ -52,6 +55,49 @@ def model_update_v2_v3(checkpoint: dict) -> None:
checkpoint["best_optimizer_state_dict"] = None


def model_update_v3_v4(checkpoint: dict) -> None:
"""
Update a v3 checkpoint to v4.

:param checkpoint: The checkpoint to update.
"""
# need to change all inv_covariance to cholesky buffers
state_dict = checkpoint["model_state_dict"]
new_state_dict = {}
for key, value in state_dict.items():
if key.startswith("inv_covariance_"):
cholesky_key = key.replace("inv_covariance_", "cholesky_")
covariance_key = key.replace("inv_covariance_", "covariance_")
covariance = state_dict[covariance_key]
# Try with an increasingly high regularization parameter until
# the matrix is positive-definite
is_not_pd = True
regularizer = 1e-20
while is_not_pd and regularizer < 1e16:
try:
cholesky = torch.linalg.cholesky(
0.5 * (covariance + covariance.T)
+ regularizer
* torch.eye(
covariance.shape[0],
device=covariance.device,
dtype=torch.float64,
)
).to(covariance.dtype)
is_not_pd = False
except RuntimeError:
regularizer *= 10.0
if is_not_pd:
raise RuntimeError(
"Could not compute Cholesky decomposition. Something went "
"wrong. Please contact the metatrain developers"
)
new_state_dict[cholesky_key] = cholesky
else:
new_state_dict[key] = value
checkpoint["model_state_dict"] = new_state_dict


def trainer_update_v1_v2(checkpoint: dict) -> None:
"""
Update trainer checkpoint from version 1 to version 2.
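For context, a minimal standalone sketch (not part of the PR) of what the v3→v4 migration does to a state dict. The buffer suffix `energy_uncertainty` and the toy covariance are made up for illustration; it assumes `metatrain` is installed so that the module above is importable:

```python
import torch

from metatrain.llpr.checkpoints import model_update_v3_v4

# Toy v3-style state dict: a "covariance_" buffer plus the matching
# "inv_covariance_" buffer that v4 replaces with a Cholesky factor.
a = torch.randn(8, 8, dtype=torch.float64)
covariance = a @ a.T  # symmetric positive-definite toy covariance
checkpoint = {
    "model_state_dict": {
        "covariance_energy_uncertainty": covariance,
        "inv_covariance_energy_uncertainty": torch.inverse(covariance),
    }
}

model_update_v3_v4(checkpoint)

state_dict = checkpoint["model_state_dict"]
assert "cholesky_energy_uncertainty" in state_dict
assert "inv_covariance_energy_uncertainty" not in state_dict
# the stored lower-triangular factor L reproduces the covariance: C ≈ L @ L.T
L = state_dict["cholesky_energy_uncertainty"]
print(torch.allclose(L @ L.T, covariance, atol=1e-8))  # True
```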
130 changes: 71 additions & 59 deletions src/metatrain/llpr/model.py
@@ -2,7 +2,6 @@
from typing import Any, Dict, Iterator, List, Literal, Optional, Union

import metatensor.torch as mts
import numpy as np
import torch
from metatensor.torch import Labels, TensorBlock, TensorMap
from metatomic.torch import (
@@ -39,7 +38,7 @@


class LLPRUncertaintyModel(ModelInterface[ModelHypers]):
__checkpoint_version__ = 3
__checkpoint_version__ = 4

# all torch devices and dtypes are supported, if they are supported by the
# wrapped model; the check is performed in the trainer
@@ -163,7 +162,7 @@ def set_wrapped_model(self, model: ModelInterface) -> None:
),
)
self.register_buffer(
f"inv_covariance_{uncertainty_name}",
f"cholesky_{uncertainty_name}",
torch.zeros(
(self.ll_feat_size, self.ll_feat_size),
dtype=dtype,
@@ -406,14 +405,15 @@ def forward(
ll_features_values.shape[0], -1, ll_features_values.shape[-1]
)

# compute PRs
# the code is the same for PR and LPR
one_over_pr_values = torch.einsum(
"icj, jk, ick -> ic",
ll_features_values,
self._get_inv_covariance(uncertainty_name),
ll_features_values,
).unsqueeze(-1)
# compute PRs; the code is the same for PR and LPR
v = torch.linalg.solve_triangular(
self._get_cholesky(uncertainty_name),
ll_features_values.reshape(-1, ll_features_values.shape[-1]).T,
upper=False,
)
one_over_pr_values = torch.sum(v**2, dim=0).reshape(
ll_features_values.shape[0], ll_features_values.shape[1], 1
)

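A quick standalone check (a sketch with made-up shapes, not PR code) that the triangular solve above reproduces the old einsum: with the covariance factored as C = L Lᵀ, fᵀ C⁻¹ f = ‖L⁻¹ f‖².

```python
import torch

torch.manual_seed(0)
n_feats = 16
a = torch.randn(n_feats, n_feats, dtype=torch.float64)
cov = a @ a.T + torch.eye(n_feats, dtype=torch.float64)  # SPD covariance
L = torch.linalg.cholesky(cov)  # cov = L @ L.T

# (structures, components, features), as in the forward pass
feats = torch.randn(5, 3, n_feats, dtype=torch.float64)

# old path: einsum against the explicit inverse covariance
old = torch.einsum("icj,jk,ick->ic", feats, torch.inverse(cov), feats).unsqueeze(-1)

# new path: one triangular solve, then a sum of squares
v = torch.linalg.solve_triangular(L, feats.reshape(-1, n_feats).T, upper=False)
new = torch.sum(v**2, dim=0).reshape(feats.shape[0], feats.shape[1], 1)

print(torch.allclose(old, new))  # True
```

Besides avoiding an explicit matrix inverse, the solve only touches the triangular factor, which is typically better conditioned than the assembled C⁻¹.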
original_name = self._get_original_name(uncertainty_name)
number_of_components = _prod(
@@ -660,48 +660,60 @@ def compute_covariance(
covariance = self._get_covariance(uncertainty_name)
torch.distributed.all_reduce(covariance)

def compute_inverse_covariance(self, regularizer: Optional[float] = None) -> None:
"""A function to compute the inverse covariance matrix.
def compute_cholesky_decomposition(
self, regularizer: Optional[float] = None
) -> None:
"""A function to compute the Cholesky decomposition of the covariance matrix.

The inverse covariance is stored as a buffer in the model.
The Cholesky decomposition is stored as a buffer in the model.

:param regularizer: A regularization parameter to ensure the matrix is
invertible. If not provided, the function will try to compute the
inverse without regularization and increase the regularization
parameter until the matrix is invertible.
positive-definite. If not provided, the function will try to compute the
Cholesky decomposition without regularization and increase the
regularization parameter until the matrix is positive-definite.
"""

for name in self.outputs_list:
uncertainty_name = _get_uncertainty_name(name)
covariance = self._get_covariance(uncertainty_name)
inv_covariance = self._get_inv_covariance(uncertainty_name)
covariance = self._get_covariance(uncertainty_name).to(dtype=torch.float64)
cholesky = self._get_cholesky(uncertainty_name)
if regularizer is not None:
inv_covariance[:] = torch.inverse(
covariance
cholesky[:] = torch.linalg.cholesky(
0.5 * (covariance + covariance.T)
+ regularizer
* torch.eye(self.ll_feat_size, device=covariance.device)
)
* torch.eye(
self.ll_feat_size, device=covariance.device, dtype=torch.float64
)
).to(cholesky.dtype)
else:
# Try with an increasingly high regularization parameter until
# the matrix is positive-definite
def is_psd(x: torch.Tensor) -> torch.Tensor:
return torch.all(torch.linalg.eigvalsh(x) >= 0.0)

for log10_sigma_squared in torch.linspace(-20.0, 16.0, 33):
if not is_psd(
covariance
+ 10**log10_sigma_squared
* torch.eye(self.ll_feat_size, device=covariance.device)
):
continue
else:
inverse = torch.inverse(
covariance
+ 10 ** (log10_sigma_squared + 2.0) # for good conditioning
* torch.eye(self.ll_feat_size, device=covariance.device)
)
inv_covariance[:] = (inverse + inverse.T) / 2.0
break
is_not_pd = True
regularizer = 1e-20
while is_not_pd and regularizer < 1e16:
try:
cholesky[:] = torch.linalg.cholesky(
0.5 * (covariance + covariance.T)
+ regularizer
* torch.eye(
self.ll_feat_size,
device=covariance.device,
dtype=torch.float64,
)
).to(cholesky.dtype)
is_not_pd = False
except RuntimeError:
regularizer *= 10.0
if is_not_pd:
raise RuntimeError(
"Could not compute Cholesky decomposition. Something went "
"wrong. Please contact the metatrain developers"
)
else:
logging.info(
f"Used regularization parameter of {regularizer:.1e} to "
"compute the Cholesky decomposition"
)
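A standalone illustration (not PR code) of how the fallback escalation settles on the smallest power-of-ten regularizer that makes the matrix positive-definite; the toy matrix with one slightly negative eigenvalue is made up:

```python
import torch

cov = torch.diag(torch.tensor([1.0, 1.0, -3e-6], dtype=torch.float64))
regularizer = 1e-20
while regularizer < 1e16:
    try:
        torch.linalg.cholesky(
            0.5 * (cov + cov.T)
            + regularizer * torch.eye(cov.shape[0], dtype=torch.float64)
        )
        break
    except RuntimeError:
        regularizer *= 10.0
print(f"{regularizer:.1e}")  # 1.0e-05: the first power of ten above 3e-6
```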

def calibrate(
self,
@@ -820,27 +832,27 @@ def generate_ensemble(self) -> None:
for name, weights in weight_tensors.items():
uncertainty_name = _get_uncertainty_name(name)
cur_multiplier = self._get_multiplier(uncertainty_name)
cur_inv_covariance = (
self._get_inv_covariance(uncertainty_name)
.clone()
.detach()
.cpu()
.numpy()
)
rng = np.random.default_rng(42)
cur_cholesky = self._get_cholesky(uncertainty_name)

ensemble_weights = []

for ii in range(weights.shape[0]):
cur_ensemble_weights = rng.multivariate_normal(
weights[ii].clone().detach().cpu().numpy(),
cur_inv_covariance * cur_multiplier.item() ** 2,
size=self.ensemble_weight_sizes[name],
method="svd",
).T
cur_ensemble_weights = torch.tensor(
cur_ensemble_weights, device=device, dtype=dtype
z = torch.randn(
(self.ll_feat_size, self.ensemble_weight_sizes[name]),
device=device,
dtype=dtype,
)
# using the Cholesky decomposition to sample from the multivariate
# normal distribution
ensemble_displacements = (
torch.linalg.solve_triangular(
cur_cholesky.T,
Contributor:
Can't we avoid transposing and passing upper=False?

Collaborator Author:
It's not the same as far as I understand.

z,
upper=True,
)
* cur_multiplier.item()
)
cur_ensemble_weights = weights[ii].unsqueeze(1) + ensemble_displacements
ensemble_weights.append(cur_ensemble_weights)

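Regarding the thread above: the transpose matters. A standalone check (not PR code, toy sizes) that solving with Lᵀ (upper=True) applies L⁻ᵀ, whose Gram matrix is the intended sampling covariance (L Lᵀ)⁻¹ = C⁻¹, while solving with L (upper=False) would give (Lᵀ L)⁻¹, a different matrix in general:

```python
import torch

torch.manual_seed(0)
n = 4
a = torch.randn(n, n, dtype=torch.float64)
cov = a @ a.T + n * torch.eye(n, dtype=torch.float64)  # SPD covariance
L = torch.linalg.cholesky(cov)  # cov = L @ L.T
eye = torch.eye(n, dtype=torch.float64)

# L^{-T} z has covariance L^{-T} L^{-1} = (L L^T)^{-1} = cov^{-1}, as intended
linv_t = torch.linalg.solve_triangular(L.T, eye, upper=True)
print(torch.allclose(linv_t @ linv_t.T, torch.inverse(cov)))  # True

# L^{-1} z would instead have covariance L^{-1} L^{-T} = (L^T L)^{-1}
linv = torch.linalg.solve_triangular(L, eye, upper=False)
print(torch.allclose(linv @ linv.T, torch.inverse(cov)))  # False
```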
ensemble_weights = torch.stack(
@@ -972,8 +984,8 @@ def _get_covariance(self, name: str) -> torch.Tensor:
raise ValueError(f"Covariance for {name} not found.")
return requested_buffer

def _get_inv_covariance(self, name: str) -> torch.Tensor:
name = "inv_covariance_" + name
def _get_cholesky(self, name: str) -> torch.Tensor:
name = "cholesky_" + name
requested_buffer = torch.tensor(0)
for n, buffer in self.named_buffers():
if n == name:
Binary file not shown.
9 changes: 3 additions & 6 deletions src/metatrain/llpr/trainer.py
@@ -158,15 +158,12 @@ def train(
model.to(device=device, dtype=dtype)

if start_epoch == 0:
logging.info(
"Computing LLPR covariance matrix "
f"using {self.hypers['calibration_method'].upper()}"
)
logging.info("Computing LLPR covariance matrix")
model.compute_covariance(
train_datasets, self.hypers["batch_size"], is_distributed
)
logging.info("Computing LLPR inverse covariance matrix")
model.compute_inverse_covariance(self.hypers["regularizer"])
logging.info("Computing Cholesky decomposition of the covariance matrix")
model.compute_cholesky_decomposition(self.hypers["regularizer"])
logging.info("Calibrating LLPR uncertainties")
model.calibrate(
val_datasets,
Expand Down
Loading