Skip to content

Commit 766333b

Browse files
Jammy2211Jammy2211
authored andcommitted
64 bit config in
1 parent 34a60d5 commit 766333b

File tree

2 files changed

+4
-3
lines changed

2 files changed

+4
-3
lines changed

autoarray/inversion/inversion/abstract.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -638,7 +638,7 @@ def log_det_curvature_reg_matrix_term(self) -> float:
638638
if not self.has(cls=AbstractRegularization):
639639
return 0.0
640640

641-
return 2.0 * np.sum(
641+
return 2.0 * jnp.sum(
642642
jnp.log(jnp.diag(jnp.linalg.cholesky(self.curvature_reg_matrix_reduced)))
643643
)
644644

@@ -659,7 +659,7 @@ def log_det_regularization_matrix_term(self) -> float:
659659
if not self.has(cls=AbstractRegularization):
660660
return 0.0
661661

662-
return 2.0 * np.sum(
662+
return 2.0 * jnp.sum(
663663
jnp.log(jnp.diag(jnp.linalg.cholesky(self.regularization_matrix_reduced)))
664664
)
665665

test_autoarray/conftest.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,10 @@
1+
import jax
12
import jax.numpy as jnp
23

3-
44
def pytest_configure():
55
_ = jnp.sum(jnp.array([0.0])) # Force backend init
66

7+
jax.config.update("jax_enable_x64", True)
78

89
import os
910
from os import path

0 commit comments

Comments
 (0)