Skip to content

Commit 97abd12

Browse files
Jammy2211Jammy2211
authored andcommitted
fix all inversion unit tests
1 parent 9c70932 commit 97abd12

File tree

2 files changed

+40
-7
lines changed

2 files changed

+40
-7
lines changed

autoarray/inversion/inversion/inversion_util.py

Lines changed: 36 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -61,9 +61,42 @@ def curvature_matrix_with_added_to_diag_from(
6161
curvature_matrix
6262
The curvature matrix which is being constructed in order to solve a linear system of equations.
6363
"""
64-
return curvature_matrix.at[
65-
no_regularization_index_list, no_regularization_index_list
66-
].add(value)
64+
try:
65+
return curvature_matrix.at[
66+
no_regularization_index_list, no_regularization_index_list
67+
].add(value)
68+
except AttributeError:
69+
return curvature_matrix_with_added_to_diag_from_numba(
70+
curvature_matrix=curvature_matrix,
71+
value=value,
72+
no_regularization_index_list=no_regularization_index_list,
73+
)
74+
75+
@numba_util.jit()
76+
def curvature_matrix_with_added_to_diag_from_numba(
77+
curvature_matrix: np.ndarray,
78+
value: float,
79+
no_regularization_index_list: Optional[List] = None,
80+
) -> np.ndarray:
81+
"""
82+
It is common for the `curvature_matrix` computed to not be positive-definite, leading for the inversion
83+
via `np.linalg.solve` to fail and raise a `LinAlgError`.
84+
85+
In many circumstances, adding a small numerical value of `1.0e-8` to the diagonal of the `curvature_matrix`
86+
makes it positive definite, such that the inversion is performed without raising an error.
87+
88+
This function adds this numerical value to the diagonal of the curvature matrix.
89+
90+
Parameters
91+
----------
92+
curvature_matrix
93+
The curvature matrix which is being constructed in order to solve a linear system of equations.
94+
"""
95+
96+
for i in no_regularization_index_list:
97+
curvature_matrix[i, i] += value
98+
99+
return curvature_matrix
67100

68101

69102
def curvature_matrix_mirrored_from(

test_autoarray/inversion/inversion/interferometer/test_inversion_interferometer_util.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -272,8 +272,8 @@ def test__identical_inversion_values_for_two_methods():
272272
== inversion_mapping_matrices.regularization_matrix
273273
).all()
274274

275-
assert inversion_w_tilde.data_vector.array == pytest.approx(
276-
inversion_mapping_matrices.data_vector.array, 1.0e-8
275+
assert inversion_w_tilde.data_vector == pytest.approx(
276+
inversion_mapping_matrices.data_vector, 1.0e-8
277277
)
278278
assert inversion_w_tilde.curvature_matrix == pytest.approx(
279279
inversion_mapping_matrices.curvature_matrix, 1.0e-8
@@ -285,8 +285,8 @@ def test__identical_inversion_values_for_two_methods():
285285
assert inversion_w_tilde.reconstruction == pytest.approx(
286286
inversion_mapping_matrices.reconstruction, abs=1.0e-1
287287
)
288-
assert inversion_w_tilde.mapped_reconstructed_image == pytest.approx(
289-
inversion_mapping_matrices.mapped_reconstructed_image, abs=1.0e-1
288+
assert inversion_w_tilde.mapped_reconstructed_image.array == pytest.approx(
289+
inversion_mapping_matrices.mapped_reconstructed_image.array, abs=1.0e-1
290290
)
291291
assert inversion_w_tilde.mapped_reconstructed_data == pytest.approx(
292292
inversion_mapping_matrices.mapped_reconstructed_data, abs=1.0e-1

0 commit comments

Comments
 (0)