File tree Expand file tree Collapse file tree 2 files changed +4
-3
lines changed
autoarray/inversion/inversion Expand file tree Collapse file tree 2 files changed +4
-3
lines changed Original file line number Diff line number Diff line change @@ -638,7 +638,7 @@ def log_det_curvature_reg_matrix_term(self) -> float:
638638 if not self .has (cls = AbstractRegularization ):
639639 return 0.0
640640
641- return 2.0 * np .sum (
641+ return 2.0 * jnp .sum (
642642 jnp .log (jnp .diag (jnp .linalg .cholesky (self .curvature_reg_matrix_reduced )))
643643 )
644644
@@ -659,7 +659,7 @@ def log_det_regularization_matrix_term(self) -> float:
659659 if not self .has (cls = AbstractRegularization ):
660660 return 0.0
661661
662- return 2.0 * np .sum (
662+ return 2.0 * jnp .sum (
663663 jnp .log (jnp .diag (jnp .linalg .cholesky (self .regularization_matrix_reduced )))
664664 )
665665
Original file line number Diff line number Diff line change 1+ import jax
12import jax .numpy as jnp
23
3-
44def pytest_configure ():
55 _ = jnp .sum (jnp .array ([0.0 ])) # Force backend init
66
7+ jax .config .update ("jax_enable_x64" , True )
78
89import os
910from os import path
You can’t perform that action at this time.
0 commit comments