diff --git a/src/metatrain/llpr/checkpoints.py b/src/metatrain/llpr/checkpoints.py
index 1771815ed6..5512cbdd59 100644
--- a/src/metatrain/llpr/checkpoints.py
+++ b/src/metatrain/llpr/checkpoints.py
@@ -1,3 +1,6 @@
+import torch
+
+
 def model_update_v1_v2(checkpoint: dict) -> None:
     """
     Update a v1 checkpoint to v2.
@@ -52,6 +55,49 @@ def model_update_v2_v3(checkpoint: dict) -> None:
     checkpoint["best_optimizer_state_dict"] = None
 
 
+def model_update_v3_v4(checkpoint: dict) -> None:
+    """
+    Update a v3 checkpoint to v4.
+
+    :param checkpoint: The checkpoint to update.
+    """
+    # replace all inv_covariance_* buffers with cholesky_* buffers
+    state_dict = checkpoint["model_state_dict"]
+    new_state_dict = {}
+    for key, value in state_dict.items():
+        if key.startswith("inv_covariance_"):
+            cholesky_key = key.replace("inv_covariance_", "cholesky_")
+            covariance_key = key.replace("inv_covariance_", "covariance_")
+            covariance = state_dict[covariance_key]
+            # Try with an increasingly high regularization parameter until
+            # the matrix is positive-definite
+            is_not_pd = True
+            regularizer = 1e-20
+            while is_not_pd and regularizer < 1e16:
+                try:
+                    cholesky = torch.linalg.cholesky(
+                        0.5 * (covariance + covariance.T)
+                        + regularizer
+                        * torch.eye(
+                            covariance.shape[0],
+                            device=covariance.device,
+                            dtype=torch.float64,
+                        )
+                    ).to(covariance.dtype)
+                    is_not_pd = False
+                except RuntimeError:
+                    regularizer *= 10.0
+            if is_not_pd:
+                raise RuntimeError(
+                    "Could not compute Cholesky decomposition. Something went "
+                    "wrong. Please contact the metatrain developers"
+                )
+            new_state_dict[cholesky_key] = cholesky
+        else:
+            new_state_dict[key] = value
+    checkpoint["model_state_dict"] = new_state_dict
+
+
 def trainer_update_v1_v2(checkpoint: dict) -> None:
     """
     Update trainer checkpoint from version 1 to version 2.
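Review note: the escalating-regularizer loop in model_update_v3_v4 above (symmetrize the covariance, add a diagonal shift, retry with a 10x stronger shift until torch.linalg.cholesky succeeds) is the same pattern used in model.py below. Here is a minimal self-contained sketch of that pattern, useful for testing it in isolation; the helper name regularized_cholesky is hypothetical and not part of this patch:

import torch


def regularized_cholesky(covariance: torch.Tensor) -> torch.Tensor:
    # Symmetrize, then grow a diagonal shift by 10x until the matrix is
    # positive-definite (torch.linalg.cholesky raises a RuntimeError
    # subclass on non-positive-definite input).
    sym = 0.5 * (covariance + covariance.T).to(torch.float64)
    eye = torch.eye(sym.shape[0], dtype=torch.float64, device=sym.device)
    regularizer = 1e-20
    while regularizer < 1e16:
        try:
            return torch.linalg.cholesky(sym + regularizer * eye).to(covariance.dtype)
        except RuntimeError:
            regularizer *= 10.0
    raise RuntimeError("Matrix could not be regularized to positive-definite.")


x = torch.randn(16, 4, dtype=torch.float64)
chol = regularized_cholesky(x @ x.T)  # rank-deficient, so a shift may be needed
assert torch.allclose(chol @ chol.T, x @ x.T, atol=1e-6)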
diff --git a/src/metatrain/llpr/model.py b/src/metatrain/llpr/model.py
index 34c6ff36da..d64a5e5ef5 100644
--- a/src/metatrain/llpr/model.py
+++ b/src/metatrain/llpr/model.py
@@ -2,7 +2,6 @@
 from typing import Any, Dict, Iterator, List, Literal, Optional, Union
 
 import metatensor.torch as mts
-import numpy as np
 import torch
 from metatensor.torch import Labels, TensorBlock, TensorMap
 from metatomic.torch import (
@@ -39,7 +38,7 @@ class LLPRUncertaintyModel(ModelInterface[ModelHypers]):
-    __checkpoint_version__ = 3
+    __checkpoint_version__ = 4
 
     # all torch devices and dtypes are supported, if they are supported by the wrapped
     # the check is performed in the trainer
@@ -163,7 +162,7 @@ def set_wrapped_model(self, model: ModelInterface) -> None:
                 ),
             )
             self.register_buffer(
-                f"inv_covariance_{uncertainty_name}",
+                f"cholesky_{uncertainty_name}",
                 torch.zeros(
                     (self.ll_feat_size, self.ll_feat_size),
                     dtype=dtype,
@@ -406,14 +405,15 @@ def forward(
                     ll_features_values.shape[0], -1, ll_features_values.shape[-1]
                 )
 
-                # compute PRs
-                # the code is the same for PR and LPR
-                one_over_pr_values = torch.einsum(
-                    "icj, jk, ick -> ic",
-                    ll_features_values,
-                    self._get_inv_covariance(uncertainty_name),
-                    ll_features_values,
-                ).unsqueeze(-1)
+                # compute PRs; the code is the same for PR and LPR
+                v = torch.linalg.solve_triangular(
+                    self._get_cholesky(uncertainty_name),
+                    ll_features_values.reshape(-1, ll_features_values.shape[-1]).T,
+                    upper=False,
+                )
+                one_over_pr_values = torch.sum(v**2, dim=0).reshape(
+                    ll_features_values.shape[0], ll_features_values.shape[1], 1
+                )
 
                 original_name = self._get_original_name(uncertainty_name)
                 number_of_components = _prod(
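Review note: the triangular solve above is equivalent to the einsum it replaces. Factoring the covariance as Sigma = L L^T gives x^T Sigma^{-1} x = ||L^{-1} x||^2, so solving L v = x and summing v**2 yields the same 1/PR values without ever materializing the inverse. A self-contained check of this identity, with illustrative names (sigma, chol, x are not from the patch):

import torch

torch.manual_seed(0)
n_feat, n_points = 8, 5
a = torch.randn(n_feat, n_feat, dtype=torch.float64)
sigma = a @ a.T + n_feat * torch.eye(n_feat, dtype=torch.float64)  # positive-definite
x = torch.randn(n_feat, n_points, dtype=torch.float64)  # columns are feature vectors

chol = torch.linalg.cholesky(sigma)  # lower-triangular L with sigma == L @ L.T
v = torch.linalg.solve_triangular(chol, x, upper=False)  # solves L v = x
via_solve = torch.sum(v**2, dim=0)

# the quadratic form x_i^T sigma^{-1} x_i, as the old einsum computed it
via_inverse = torch.einsum("ji,jk,ki->i", x, torch.inverse(sigma), x)

assert torch.allclose(via_solve, via_inverse)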
""" for name in self.outputs_list: uncertainty_name = _get_uncertainty_name(name) - covariance = self._get_covariance(uncertainty_name) - inv_covariance = self._get_inv_covariance(uncertainty_name) + covariance = self._get_covariance(uncertainty_name).to(dtype=torch.float64) + cholesky = self._get_cholesky(uncertainty_name) if regularizer is not None: - inv_covariance[:] = torch.inverse( - covariance + cholesky[:] = torch.linalg.cholesky( + 0.5 * (covariance + covariance.T) + regularizer - * torch.eye(self.ll_feat_size, device=covariance.device) - ) + * torch.eye( + self.ll_feat_size, device=covariance.device, dtype=torch.float64 + ) + ).to(cholesky.dtype) else: # Try with an increasingly high regularization parameter until # the matrix is invertible - def is_psd(x: torch.Tensor) -> torch.Tensor: - return torch.all(torch.linalg.eigvalsh(x) >= 0.0) - - for log10_sigma_squared in torch.linspace(-20.0, 16.0, 33): - if not is_psd( - covariance - + 10**log10_sigma_squared - * torch.eye(self.ll_feat_size, device=covariance.device) - ): - continue - else: - inverse = torch.inverse( - covariance - + 10 ** (log10_sigma_squared + 2.0) # for good conditioning - * torch.eye(self.ll_feat_size, device=covariance.device) - ) - inv_covariance[:] = (inverse + inverse.T) / 2.0 - break + is_not_pd = True + regularizer = 1e-20 + while is_not_pd and regularizer < 1e16: + try: + cholesky[:] = torch.linalg.cholesky( + 0.5 * (covariance + covariance.T) + + regularizer + * torch.eye( + self.ll_feat_size, + device=covariance.device, + dtype=torch.float64, + ) + ).to(cholesky.dtype) + is_not_pd = False + except RuntimeError: + regularizer *= 10.0 + if is_not_pd: + raise RuntimeError( + "Could not compute Cholesky decomposition. Something went " + "wrong. Please contact the metatrain developers" + ) + else: + logging.info( + f"Used regularization parameter of {regularizer:.1e} to " + "compute the Cholesky decomposition" + ) def calibrate( self, @@ -820,27 +832,27 @@ def generate_ensemble(self) -> None: for name, weights in weight_tensors.items(): uncertainty_name = _get_uncertainty_name(name) cur_multiplier = self._get_multiplier(uncertainty_name) - cur_inv_covariance = ( - self._get_inv_covariance(uncertainty_name) - .clone() - .detach() - .cpu() - .numpy() - ) - rng = np.random.default_rng(42) + cur_cholesky = self._get_cholesky(uncertainty_name) ensemble_weights = [] for ii in range(weights.shape[0]): - cur_ensemble_weights = rng.multivariate_normal( - weights[ii].clone().detach().cpu().numpy(), - cur_inv_covariance * cur_multiplier.item() ** 2, - size=self.ensemble_weight_sizes[name], - method="svd", - ).T - cur_ensemble_weights = torch.tensor( - cur_ensemble_weights, device=device, dtype=dtype + z = torch.randn( + (self.ll_feat_size, self.ensemble_weight_sizes[name]), + device=device, + dtype=dtype, + ) + # using the Cholesky decomposition to sample from the multivariate + # normal distribution + ensemble_displacements = ( + torch.linalg.solve_triangular( + cur_cholesky.T, + z, + upper=True, + ) + * cur_multiplier.item() ) + cur_ensemble_weights = weights[ii].unsqueeze(1) + ensemble_displacements ensemble_weights.append(cur_ensemble_weights) ensemble_weights = torch.stack( @@ -972,8 +984,8 @@ def _get_covariance(self, name: str) -> torch.Tensor: raise ValueError(f"Covariance for {name} not found.") return requested_buffer - def _get_inv_covariance(self, name: str) -> torch.Tensor: - name = "inv_covariance_" + name + def _get_cholesky(self, name: str) -> torch.Tensor: + name = "cholesky_" + name requested_buffer 
@@ -972,8 +984,8 @@ def _get_covariance(self, name: str) -> torch.Tensor:
             raise ValueError(f"Covariance for {name} not found.")
         return requested_buffer
 
-    def _get_inv_covariance(self, name: str) -> torch.Tensor:
-        name = "inv_covariance_" + name
+    def _get_cholesky(self, name: str) -> torch.Tensor:
+        name = "cholesky_" + name
         requested_buffer = torch.tensor(0)
         for n, buffer in self.named_buffers():
             if n == name:
diff --git a/src/metatrain/llpr/tests/checkpoints/model-v4_trainer-v5.ckpt.gz b/src/metatrain/llpr/tests/checkpoints/model-v4_trainer-v5.ckpt.gz
new file mode 100644
index 0000000000..47b094c24a
Binary files /dev/null and b/src/metatrain/llpr/tests/checkpoints/model-v4_trainer-v5.ckpt.gz differ
diff --git a/src/metatrain/llpr/trainer.py b/src/metatrain/llpr/trainer.py
index a00bf0683b..deda6af7e7 100644
--- a/src/metatrain/llpr/trainer.py
+++ b/src/metatrain/llpr/trainer.py
@@ -158,15 +158,12 @@ def train(
         model.to(device=device, dtype=dtype)
 
         if start_epoch == 0:
-            logging.info(
-                "Computing LLPR covariance matrix "
-                f"using {self.hypers['calibration_method'].upper()}"
-            )
+            logging.info("Computing LLPR covariance matrix")
             model.compute_covariance(
                 train_datasets, self.hypers["batch_size"], is_distributed
             )
-            logging.info("Computing LLPR inverse covariance matrix")
-            model.compute_inverse_covariance(self.hypers["regularizer"])
+            logging.info("Computing Cholesky decomposition of the covariance matrix")
+            model.compute_cholesky_decomposition(self.hypers["regularizer"])
             logging.info("Calibrating LLPR uncertainties")
             model.calibrate(
                 val_datasets,
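Review note: the v3 -> v4 migration can be exercised on a toy state dict, independently of the bundled model-v4_trainer-v5.ckpt.gz fixture. A minimal sketch; the "energy" buffer suffix is illustrative, while the import path follows the patch:

import torch

from metatrain.llpr.checkpoints import model_update_v3_v4

# a fake v3 checkpoint: a covariance buffer plus its (now obsolete) inverse
a = torch.randn(6, 6, dtype=torch.float64)
covariance = a @ a.T + torch.eye(6, dtype=torch.float64)
checkpoint = {
    "model_state_dict": {
        "covariance_energy": covariance,
        "inv_covariance_energy": torch.inverse(covariance),
    }
}

model_update_v3_v4(checkpoint)

state_dict = checkpoint["model_state_dict"]
assert "inv_covariance_energy" not in state_dict  # replaced by a Cholesky buffer
chol = state_dict["cholesky_energy"]
assert torch.allclose(chol @ chol.T, covariance, atol=1e-6)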