From 924273dad364b80a7aa5d034318bb61ccf885d9e Mon Sep 17 00:00:00 2001
From: frostedoyster
Date: Thu, 12 Feb 2026 12:37:47 +0100
Subject: [PATCH 1/3] Use Cholesky decomposition for a more stable LLPR

---
 src/metatrain/llpr/checkpoints.py             |  47 +++++++
 src/metatrain/llpr/model.py                   | 128 ++++++++++--------
 .../checkpoints/model-v4_trainer-v5.ckpt.gz   | Bin 0 -> 8521 bytes
 src/metatrain/llpr/tests/test_llpr.py         |   2 +
 src/metatrain/llpr/trainer.py                 |   9 +-
 5 files changed, 123 insertions(+), 63 deletions(-)
 create mode 100644 src/metatrain/llpr/tests/checkpoints/model-v4_trainer-v5.ckpt.gz

diff --git a/src/metatrain/llpr/checkpoints.py b/src/metatrain/llpr/checkpoints.py
index 1771815ed6..574ebe63c0 100644
--- a/src/metatrain/llpr/checkpoints.py
+++ b/src/metatrain/llpr/checkpoints.py
@@ -1,3 +1,6 @@
+import torch
+
+
 def model_update_v1_v2(checkpoint: dict) -> None:
     """
     Update a v1 checkpoint to v2.
@@ -52,6 +55,50 @@ def model_update_v2_v3(checkpoint: dict) -> None:
     checkpoint["best_optimizer_state_dict"] = None
 
 
+def model_update_v3_v4(checkpoint: dict) -> None:
+    """
+    Update a v3 checkpoint to v4.
+
+    :param checkpoint: The checkpoint to update.
+    """
+    # need to change all inv_covariance to cholesky buffers
+    state_dict = checkpoint["model_state_dict"]
+    new_state_dict = {}
+    for key, value in state_dict.items():
+        if key.startswith("inv_covariance_"):
+            cholesky_key = key.replace("inv_covariance_", "cholesky_")
+            covariance_key = key.replace("inv_covariance_", "covariance_")
+            covariance = state_dict[covariance_key]
+            # Try with an increasingly high regularization parameter until
+            # the matrix is invertible
+            is_not_pd = True
+            regularizer = 1e-20
+            while is_not_pd and regularizer < 1e16:
+                try:
+                    cholesky = torch.linalg.cholesky(
+                        covariance
+                        + covariance.T
+                        + regularizer
+                        * torch.eye(
+                            covariance.shape[0],
+                            device=covariance.device,
+                            dtype=torch.float64,
+                        )
+                    ).to(covariance.dtype)
+                    is_not_pd = False
+                except RuntimeError:
+                    regularizer *= 10.0
+            if is_not_pd:
+                raise RuntimeError(
+                    "Could not compute Cholesky decomposition. Something went "
+                    "wrong. Please contact the metatrain developers"
+                )
+            new_state_dict[cholesky_key] = cholesky
+        else:
+            new_state_dict[key] = value
+    checkpoint["model_state_dict"] = new_state_dict
+
+
 def trainer_update_v1_v2(checkpoint: dict) -> None:
    """
    Update trainer checkpoint from version 1 to version 2.
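
A note on the hunk above (illustration, not part of the patch): the migration converts every stored inv_covariance_* buffer into a cholesky_* buffer by factoring the symmetrized, progressively regularized covariance. A minimal standalone sketch of the same escalating-regularizer idea, written against the 0.5 * (C + C.T) symmetrization that PATCH 2/3 below settles on, with all names invented for the example:

import torch

def regularized_cholesky(covariance: torch.Tensor) -> torch.Tensor:
    """Return a lower-triangular factor L with L @ L.T ~= covariance."""
    # Symmetrize first: accumulated covariances are symmetric only up to
    # floating-point noise, which is already enough to make cholesky() fail.
    sym = 0.5 * (covariance + covariance.T)
    eye = torch.eye(sym.shape[0], device=sym.device, dtype=sym.dtype)
    regularizer = 1e-20
    while regularizer < 1e16:
        try:
            # torch.linalg.cholesky raises a subclass of RuntimeError when the
            # matrix is not positive-definite, so a failure selects the next step.
            return torch.linalg.cholesky(sym + regularizer * eye)
        except RuntimeError:
            regularizer *= 10.0
    raise RuntimeError("covariance could not be regularized to positive-definite")

The loop is bounded: stepping the regularizer by factors of 10 from 1e-20 up to 1e16 gives at most 36 attempts before giving up.
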
diff --git a/src/metatrain/llpr/model.py b/src/metatrain/llpr/model.py
index 34c6ff36da..041cf5cec1 100644
--- a/src/metatrain/llpr/model.py
+++ b/src/metatrain/llpr/model.py
@@ -2,7 +2,6 @@
 from typing import Any, Dict, Iterator, List, Literal, Optional, Union
 
 import metatensor.torch as mts
-import numpy as np
 import torch
 from metatensor.torch import Labels, TensorBlock, TensorMap
 from metatomic.torch import (
@@ -39,7 +38,7 @@
 
 
 class LLPRUncertaintyModel(ModelInterface[ModelHypers]):
-    __checkpoint_version__ = 3
+    __checkpoint_version__ = 4
 
     # all torch devices and dtypes are supported, if they are supported by the wrapped
     # the check is performed in the trainer
@@ -163,7 +162,7 @@ def set_wrapped_model(self, model: ModelInterface) -> None:
                 ),
             )
             self.register_buffer(
-                f"inv_covariance_{uncertainty_name}",
+                f"cholesky_{uncertainty_name}",
                 torch.zeros(
                     (self.ll_feat_size, self.ll_feat_size),
                     dtype=dtype,
@@ -406,14 +405,15 @@ def forward(
                     ll_features_values.shape[0], -1, ll_features_values.shape[-1]
                 )
 
-                # compute PRs
-                # the code is the same for PR and LPR
-                one_over_pr_values = torch.einsum(
-                    "icj, jk, ick -> ic",
-                    ll_features_values,
-                    self._get_inv_covariance(uncertainty_name),
-                    ll_features_values,
-                ).unsqueeze(-1)
+                # compute PRs; the code is the same for PR and LPR
+                v = torch.linalg.solve_triangular(
+                    self._get_cholesky(uncertainty_name),
+                    ll_features_values.reshape(-1, ll_features_values.shape[-1]).T,
+                    upper=False,
+                )
+                one_over_pr_values = torch.sum(v**2, dim=0).reshape(
+                    ll_features_values.shape[0], ll_features_values.shape[1], 1
+                )
 
                 original_name = self._get_original_name(uncertainty_name)
                 number_of_components = _prod(
@@ -660,48 +660,62 @@ def compute_covariance(
                 covariance = self._get_covariance(uncertainty_name)
                 torch.distributed.all_reduce(covariance)
 
-    def compute_inverse_covariance(self, regularizer: Optional[float] = None) -> None:
-        """A function to compute the inverse covariance matrix.
+    def compute_cholesky_decomposition(
+        self, regularizer: Optional[float] = None
+    ) -> None:
+        """A function to compute the Cholesky decomposition of the covariance matrix.
 
-        The inverse covariance is stored as a buffer in the model.
+        The Cholesky decomposition is stored as a buffer in the model.
 
         :param regularizer: A regularization parameter to ensure the matrix is
-            invertible. If not provided, the function will try to compute the
-            inverse without regularization and increase the regularization
-            parameter until the matrix is invertible.
+            positive-definite. If not provided, the function will try to compute the
+            Cholesky decomposition without regularization and increase the
+            regularization parameter until the matrix is positive-definite.
         """
         for name in self.outputs_list:
             uncertainty_name = _get_uncertainty_name(name)
-            covariance = self._get_covariance(uncertainty_name)
-            inv_covariance = self._get_inv_covariance(uncertainty_name)
+            covariance = self._get_covariance(uncertainty_name).to(dtype=torch.float64)
+            cholesky = self._get_cholesky(uncertainty_name)
             if regularizer is not None:
-                inv_covariance[:] = torch.inverse(
+                cholesky[:] = torch.linalg.cholesky(
                     covariance
+                    + covariance.T
                     + regularizer
-                    * torch.eye(self.ll_feat_size, device=covariance.device)
-                )
+                    * torch.eye(
+                        self.ll_feat_size, device=covariance.device, dtype=torch.float64
+                    )
+                ).to(cholesky.dtype)
             else:
                 # Try with an increasingly high regularization parameter until
                 # the matrix is invertible
-                def is_psd(x: torch.Tensor) -> torch.Tensor:
-                    return torch.all(torch.linalg.eigvalsh(x) >= 0.0)
-
-                for log10_sigma_squared in torch.linspace(-20.0, 16.0, 33):
-                    if not is_psd(
-                        covariance
-                        + 10**log10_sigma_squared
-                        * torch.eye(self.ll_feat_size, device=covariance.device)
-                    ):
-                        continue
-                    else:
-                        inverse = torch.inverse(
+                is_not_pd = True
+                regularizer = 1e-20
+                while is_not_pd and regularizer < 1e16:
+                    try:
+                        cholesky[:] = torch.linalg.cholesky(
                             covariance
-                            + 10 ** (log10_sigma_squared + 2.0)  # for good conditioning
-                            * torch.eye(self.ll_feat_size, device=covariance.device)
-                        )
-                        inv_covariance[:] = (inverse + inverse.T) / 2.0
-                        break
+                            + covariance.T
+                            + regularizer
+                            * torch.eye(
+                                self.ll_feat_size,
+                                device=covariance.device,
+                                dtype=torch.float64,
+                            )
+                        ).to(cholesky.dtype)
+                        is_not_pd = False
+                    except RuntimeError:
+                        regularizer *= 10.0
+                if is_not_pd:
+                    raise RuntimeError(
+                        "Could not compute Cholesky decomposition. Something went "
+                        "wrong. Please contact the metatrain developers"
+                    )
+                else:
+                    logging.info(
+                        f"Used regularization parameter of {regularizer:.1e} to "
+                        "compute the Cholesky decomposition"
+                    )
 
     def calibrate(
         self,
@@ -820,27 +834,27 @@ def generate_ensemble(self) -> None:
         for name, weights in weight_tensors.items():
             uncertainty_name = _get_uncertainty_name(name)
             cur_multiplier = self._get_multiplier(uncertainty_name)
-            cur_inv_covariance = (
-                self._get_inv_covariance(uncertainty_name)
-                .clone()
-                .detach()
-                .cpu()
-                .numpy()
-            )
-            rng = np.random.default_rng(42)
+            cur_cholesky = self._get_cholesky(uncertainty_name)
 
             ensemble_weights = []
             for ii in range(weights.shape[0]):
-                cur_ensemble_weights = rng.multivariate_normal(
-                    weights[ii].clone().detach().cpu().numpy(),
-                    cur_inv_covariance * cur_multiplier.item() ** 2,
-                    size=self.ensemble_weight_sizes[name],
-                    method="svd",
-                ).T
-                cur_ensemble_weights = torch.tensor(
-                    cur_ensemble_weights, device=device, dtype=dtype
+                z = torch.randn(
+                    (self.ll_feat_size, self.ensemble_weight_sizes[name]),
+                    device=device,
+                    dtype=dtype,
+                )
+                # using the Cholesky decomposition to sample from the multivariate
+                # normal distribution
+                ensemble_displacements = (
+                    torch.linalg.solve_triangular(
+                        cur_cholesky.T,
+                        z,
+                        upper=True,
+                    )
+                    * cur_multiplier.item()
                 )
+                cur_ensemble_weights = weights[ii].unsqueeze(1) + ensemble_displacements
                 ensemble_weights.append(cur_ensemble_weights)
 
             ensemble_weights = torch.stack(
@@ -972,8 +986,8 @@ def _get_covariance(self, name: str) -> torch.Tensor:
             raise ValueError(f"Covariance for {name} not found.")
         return requested_buffer
 
-    def _get_inv_covariance(self, name: str) -> torch.Tensor:
-        name = "inv_covariance_" + name
+    def _get_cholesky(self, name: str) -> torch.Tensor:
+        name = "cholesky_" + name
         requested_buffer = torch.tensor(0)
         for n, buffer in self.named_buffers():
             if n == name:
diff --git a/src/metatrain/llpr/tests/checkpoints/model-v4_trainer-v5.ckpt.gz b/src/metatrain/llpr/tests/checkpoints/model-v4_trainer-v5.ckpt.gz
new file mode 100644
index 0000000000000000000000000000000000000000..47b094c24a8675e8853f713080c6ef2569329982
GIT binary patch
literal 8521
[base85-encoded binary payload of the test checkpoint omitted]
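
A note on the forward() change in the model.py diff above (illustration, not part of the patch): with the covariance factored as Sigma = L @ L.T, the contraction f.T @ inverse(Sigma) @ f equals the squared norm of L^-1 @ f, so the explicit inverse can be replaced by a single triangular solve. A self-contained check of that identity, with all tensor names and sizes invented for the example:

import torch

torch.manual_seed(0)
n_samples, n_features = 8, 16
feats = torch.randn(n_samples, n_features, dtype=torch.float64)

# A synthetic positive-definite covariance, standing in for the accumulated
# last-layer feature covariance.
a = torch.randn(n_features, n_features, dtype=torch.float64)
covariance = a @ a.T + 1e-3 * torch.eye(n_features, dtype=torch.float64)

# Old route: contract the features with the explicit inverse.
one_over_pr_inverse = torch.einsum(
    "ij,jk,ik->i", feats, torch.linalg.inv(covariance), feats
)

# New route: factor once, then do a triangular solve (covariance = L @ L.T).
cholesky = torch.linalg.cholesky(covariance)
v = torch.linalg.solve_triangular(cholesky, feats.T, upper=False)  # v = L^-1 @ f
one_over_pr_cholesky = torch.sum(v**2, dim=0)

assert torch.allclose(one_over_pr_inverse, one_over_pr_cholesky)

The triangular factor L has the square root of the covariance's condition number, which is presumably where the extra numerical stability of the new route comes from.
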
From: frostedoyster
Date: Thu, 12 Feb 2026 15:40:54 +0100
Subject: [PATCH 2/3] Fix checkpoint update and make the llpr test pass

---
 src/metatrain/llpr/checkpoints.py     | 3 +--
 src/metatrain/llpr/tests/test_llpr.py | 2 --
 2 files changed, 1 insertion(+), 4 deletions(-)

diff --git a/src/metatrain/llpr/checkpoints.py b/src/metatrain/llpr/checkpoints.py
index 574ebe63c0..5512cbdd59 100644
--- a/src/metatrain/llpr/checkpoints.py
+++ b/src/metatrain/llpr/checkpoints.py
@@ -76,8 +76,7 @@ def model_update_v3_v4(checkpoint: dict) -> None:
             while is_not_pd and regularizer < 1e16:
                 try:
                     cholesky = torch.linalg.cholesky(
-                        covariance
-                        + covariance.T
+                        0.5 * (covariance + covariance.T)
                         + regularizer
                         * torch.eye(
                             covariance.shape[0],
diff --git a/src/metatrain/llpr/tests/test_llpr.py b/src/metatrain/llpr/tests/test_llpr.py
index 9f49fd0ba1..04428db1f9 100644
--- a/src/metatrain/llpr/tests/test_llpr.py
+++ b/src/metatrain/llpr/tests/test_llpr.py
@@ -140,8 +140,6 @@ def check_exported_model_predictions(
     )
     # require lower precision for PET-MAD which only has 128 ensemble members
     required_precision = 3e-2 if ensemble.shape[1] < 1000 else 1e-2
-    print(calc._model.module.multiplier_energy_uncertainty)
-    print(uncertainty/ensemble.std(dim=1, keepdim=True))
     assert torch.allclose(
         uncertainty,
         ensemble.std(dim=1, keepdim=True),
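
A note on the fix in PATCH 2/3 and 3/3 (illustration, not part of the patches): for an already essentially symmetric covariance, covariance + covariance.T equals 2 * covariance, so its Cholesky factor comes out inflated by sqrt(2) and every derived variance by a factor of 2 relative to the stored calibration, whereas 0.5 * (covariance + covariance.T) only symmetrizes. A tiny self-contained check, with all names invented for the example:

import math

import torch

torch.manual_seed(0)
a = torch.randn(4, 4, dtype=torch.float64)
covariance = a @ a.T + 1e-3 * torch.eye(4, dtype=torch.float64)

doubled = torch.linalg.cholesky(covariance + covariance.T)              # pre-fix form
symmetrized = torch.linalg.cholesky(0.5 * (covariance + covariance.T))  # post-fix form

# cholesky(2 * S) = sqrt(2) * cholesky(S) for any positive-definite S, so the
# pre-fix factor is exactly sqrt(2) times too large.
assert torch.allclose(doubled, math.sqrt(2.0) * symmetrized)
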
From ed897fdd0e53ead979ddbe01bb64ad1a146fb6cb Mon Sep 17 00:00:00 2001
From: frostedoyster
Date: Thu, 12 Feb 2026 15:56:12 +0100
Subject: [PATCH 3/3] Fix the same bug in the model

---
 src/metatrain/llpr/model.py | 6 ++----
 1 file changed, 2 insertions(+), 4 deletions(-)

diff --git a/src/metatrain/llpr/model.py b/src/metatrain/llpr/model.py
index 041cf5cec1..d64a5e5ef5 100644
--- a/src/metatrain/llpr/model.py
+++ b/src/metatrain/llpr/model.py
@@ -679,8 +679,7 @@ def compute_cholesky_decomposition(
             cholesky = self._get_cholesky(uncertainty_name)
             if regularizer is not None:
                 cholesky[:] = torch.linalg.cholesky(
-                    covariance
-                    + covariance.T
+                    0.5 * (covariance + covariance.T)
                     + regularizer
                     * torch.eye(
                         self.ll_feat_size, device=covariance.device, dtype=torch.float64
                     )
@@ -694,8 +693,7 @@ def compute_cholesky_decomposition(
                 while is_not_pd and regularizer < 1e16:
                     try:
                         cholesky[:] = torch.linalg.cholesky(
-                            covariance
-                            + covariance.T
+                            0.5 * (covariance + covariance.T)
                             + regularizer
                             * torch.eye(
                                 self.ll_feat_size,
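
A closing note on the generate_ensemble() change from PATCH 1/3 (illustration, not part of the patches): sampling the weight perturbations as multiplier * L^-T @ z with z drawn from a standard normal gives covariance multiplier^2 * inverse(L @ L.T) = multiplier^2 * inverse(covariance), i.e. the same distribution the removed numpy multivariate_normal call drew from, but without ever forming the inverse. A small numerical check of that identity, with all names and sizes invented for the example:

import torch

torch.manual_seed(0)
n_features, n_samples = 8, 200_000
a = torch.randn(n_features, n_features, dtype=torch.float64)
covariance = a @ a.T + 1e-1 * torch.eye(n_features, dtype=torch.float64)
cholesky = torch.linalg.cholesky(covariance)  # covariance = L @ L.T

# Displacements L^-T @ z are exact samples from N(0, inverse(covariance)).
z = torch.randn(n_features, n_samples, dtype=torch.float64)
displacements = torch.linalg.solve_triangular(cholesky.T, z, upper=True)

empirical = displacements @ displacements.T / n_samples
expected = torch.linalg.inv(covariance)
assert torch.allclose(empirical, expected, atol=0.05 * expected.abs().max().item())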