Skip to content

Commit 12b2d72

Browse files
Jammy2211Jammy2211
authored andcommitted
fix last unit test
1 parent 415926f commit 12b2d72

File tree

5 files changed

+80
-83
lines changed

5 files changed

+80
-83
lines changed

autoarray/inversion/inversion/abstract.py

Lines changed: 19 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -254,7 +254,7 @@ def no_regularization_index_list(self) -> List[int]:
254254
return no_regularization_index_list
255255

256256
@property
257-
def mapper_indices(self) -> np.ndarray[]:
257+
def mapper_indices(self) -> np.ndarray:
258258

259259
if self.preloads.mapper_indices is not None:
260260
return self.preloads.mapper_indices
@@ -421,45 +421,31 @@ def reconstruction(self) -> np.ndarray:
421421
ZTx := np.dot(Z.T, x)
422422
"""
423423
if self.settings.use_positive_only_solver:
424-
"""
425-
For the new implementation, we now need to take out the cols and rows of
426-
the curvature_reg_matrix that corresponds to the parameters we force to be 0.
427-
Similar for the data vector.
428-
429-
What we actually doing is that we have set the correspoding cols of the Z to be 0.
430-
As the curvature_reg_matrix = ZTZ, so the cols and rows are all taken out.
431-
And the data_vector = ZTx, so the corresponding row is also taken out.
432-
"""
433-
434-
if (
435-
self.has(cls=AbstractMapper)
436-
and self.settings.force_edge_pixels_to_zeros
437-
):
438-
439-
# ids of values which are on edge so zero-d and not solved for.
440-
ids_to_remove = jnp.array(self.mapper_edge_pixel_list, dtype=int)
441-
442-
# Create a boolean mask: True = keep, False = ignore
443-
mask = (
444-
jnp.ones(self.data_vector.shape[0], dtype=bool)
445-
.at[ids_to_remove]
446-
.set(False)
447-
)
448424

449-
# Zero out entries we don't want to solve for
450-
data_vector_masked = self.data_vector * mask
425+
if self.preloads.source_pixel_zeroed_indices is not None:
426+
427+
# ids of values which are not zeroed and therefore kept in soluiton, which is computed in preloads.
428+
ids_to_keep = self.preloads.source_pixel_zeroed_indices_to_keep
451429

452-
# Zero rows and columns in the matrix we want to ignore
453-
mask_matrix = mask[:, None] * mask[None, :]
454-
curvature_reg_matrix_masked = self.curvature_reg_matrix * mask_matrix
430+
# Use advanced indexing to select rows/columns
431+
data_vector = self.data_vector[ids_to_keep]
432+
curvature_reg_matrix = self.curvature_reg_matrix[ids_to_keep][:, ids_to_keep]
455433

456434
# Perform reconstruction via fnnls
457-
return inversion_util.reconstruction_positive_only_from(
458-
data_vector=data_vector_masked,
459-
curvature_reg_matrix=curvature_reg_matrix_masked,
435+
reconstruction_partial = inversion_util.reconstruction_positive_only_from(
436+
data_vector=data_vector,
437+
curvature_reg_matrix=curvature_reg_matrix,
460438
settings=self.settings,
461439
)
462440

441+
# Allocate full solution array
442+
reconstruction = jnp.zeros(self.data_vector.shape[0])
443+
444+
# Scatter the partial solution back to the full shape
445+
reconstruction = reconstruction.at[ids_to_keep].set(reconstruction_partial)
446+
447+
return reconstruction
448+
463449
else:
464450

465451
return inversion_util.reconstruction_positive_only_from(

autoarray/preloads.py

Lines changed: 36 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import logging
22

3+
import jax.numpy as jnp
34
import numpy as np
45

56
logger = logging.getLogger(__name__)
@@ -9,28 +10,47 @@
910

1011
class Preloads:
1112

12-
def __init__(self, mapper_indices: np.ndarray = None):
13+
def __init__(self, mapper_indices: np.ndarray = None, source_pixel_zeroed_indices: np.ndarray = None):
1314
"""
14-
Preload in memory arrays and matrices used to perform pixelized linear inversions, for both key functionality
15-
and speeding up the run-time of the inversion.
15+
Stores preloaded arrays and matrices used during pixelized linear inversions, improving both performance
16+
and compatibility with JAX.
1617
17-
Certain preloading arrays (e.g. `mapper_indices`) are stored here because JAX requires that they are
18-
known and defined as static arrays before sampling. During each inversion, the preloads will be inspected
19-
for these fixed arrays and used to change matrix shapes in an identical way for every likelihood evaluation.
20-
21-
Other preloading arrays are used purely to speed up the run-time of the inversion, such as
22-
the `curvature_matrix_preload` array. For certain models (e.g. if the source model is fixed and only the
23-
lens light is being fitted for), certain quadrants of the `curvature_matrix` are fixed
24-
for every likelihood evaluation, meaning that they can be preloaded and used to speed up the inversion.
18+
Some arrays (e.g. `mapper_indices`) are required to be defined before sampling begins, because JAX demands
19+
that input shapes remain static. These are used during each inversion to ensure consistent matrix shapes
20+
for all likelihood evaluations.
2521
22+
Other arrays (e.g. parts of the curvature matrix) are preloaded purely to improve performance. In cases where
23+
the source model is fixed (e.g. when fitting only the lens light), sections of the curvature matrix do not
24+
change and can be reused, avoiding redundant computation.
2625
2726
Parameters
2827
----------
2928
mapper_indices
30-
The integer indexes of the mapper pixels in a pixeized inversion, which separate their indexes from those
31-
of linear light profiles in the inversion. This is used to extract `_reduced`
32-
matrices (e.g. `curvature_matrix_reduced`) to compute the `log_evidence` terms of the pixelized inversion
33-
likelihood function.
29+
The integer indices of mapper pixels in the inversion. Used to extract reduced matrices (e.g.
30+
`curvature_matrix_reduced`) that compute the pixelized inversion's log evidence term, where the indicies
31+
are requirred to separate the rows and columns of matrices from linear light profiles.
32+
source_pixel_zeroed_indices
33+
Indices of source pixels that should be set to zero in the reconstruction. These typically correspond to
34+
outer-edge source-plane regions with no image-plane mapping (e.g. outside a circular mask), helping
35+
separate the lens light from the pixelized source model.
3436
"""
3537

36-
self.mapper_indices = mapper_indices
38+
self.mapper_indices = None
39+
self.source_pixel_zeroed_indices = None
40+
self.source_pixel_zeroed_indices_to_keep = None
41+
42+
if mapper_indices is not None:
43+
44+
self.mapper_indices = jnp.array(mapper_indices)
45+
46+
if source_pixel_zeroed_indices is not None:
47+
48+
self.source_pixel_zeroed_indices = jnp.array(source_pixel_zeroed_indices)
49+
50+
ids_zeros = jnp.array(source_pixel_zeroed_indices, dtype=int)
51+
52+
values_to_solve = jnp.ones(np.max(mapper_indices), dtype=bool)
53+
values_to_solve = values_to_solve.at[ids_zeros].set(False)
54+
55+
# Get the indices where values_to_solve is True
56+
self.source_pixel_zeroed_indices_to_keep = jnp.where(values_to_solve)[0]

test_autoarray/inversion/inversion/test_abstract.py

Lines changed: 0 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -257,33 +257,6 @@ def test__curvature_reg_matrix_reduced():
257257
).all()
258258

259259

260-
# def test__curvature_reg_matrix_solver__edge_pixels_set_to_zero():
261-
#
262-
# curvature_reg_matrix = np.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]])
263-
#
264-
# linear_obj_list = [
265-
# aa.m.MockMapper(parameters=3, regularization=None, edge_pixel_list=[0])
266-
# ]
267-
#
268-
# inversion = aa.m.MockInversion(
269-
# linear_obj_list=linear_obj_list,
270-
# curvature_reg_matrix=curvature_reg_matrix,
271-
# settings=aa.SettingsInversion(force_edge_pixels_to_zeros=True),
272-
# )
273-
#
274-
# curvature_reg_matrix = np.array(
275-
# [
276-
# [0.0, 2.0, 3.0],
277-
# [0.0, 5.0, 6.0],
278-
# [0.0, 8.0, 9.0],
279-
# ]
280-
# )
281-
#
282-
# assert inversion.curvature_reg_matrix_solver == pytest.approx(
283-
# curvature_reg_matrix, 1.0e-4
284-
# )
285-
286-
287260
def test__regularization_matrix():
288261
reg_0 = aa.m.MockRegularization(regularization_matrix=np.ones((2, 2)))
289262
reg_1 = aa.m.MockRegularization(regularization_matrix=2.0 * np.ones((3, 3)))

test_autoarray/inversion/inversion/test_factory.py

Lines changed: 25 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -189,6 +189,25 @@ def test__inversion_imaging__via_regularizations(
189189
assert inversion.mapped_reconstructed_image == pytest.approx(np.ones(9), 1.0e-4)
190190

191191

192+
def test__inversion_imaging__source_pixel_zeroed_indices(
193+
masked_imaging_7x7_no_blur,
194+
rectangular_mapper_7x7_3x3,
195+
):
196+
inversion = aa.Inversion(
197+
dataset=masked_imaging_7x7_no_blur,
198+
linear_obj_list=[rectangular_mapper_7x7_3x3],
199+
settings=aa.SettingsInversion(use_w_tilde=False, use_positive_only_solver=True),
200+
preloads=aa.Preloads(
201+
mapper_indices=range(0, 9),
202+
source_pixel_zeroed_indices=np.array([0])
203+
)
204+
)
205+
206+
assert inversion.reconstruction.shape[0] == 9
207+
assert inversion.reconstruction[0] == 0.0
208+
assert inversion.reconstruction[1] > 0.0
209+
210+
192211
def test__inversion_imaging__via_linear_obj_func_and_mapper(
193212
masked_imaging_7x7_no_blur,
194213
rectangular_mapper_7x7_3x3,
@@ -557,19 +576,19 @@ def test__inversion_matrices__x2_mappers(
557576

558577
assert inversion.reconstruction_dict[rectangular_mapper_7x7_3x3][
559578
4
560-
] == pytest.approx(0.004607102, 1.0e-4)
579+
] == pytest.approx( 0.5000029374603968, 1.0e-4)
561580
assert inversion.reconstruction_dict[delaunay_mapper_9_3x3][4] == pytest.approx(
562-
0.0475967358, 1.0e-4
581+
0.4999970390886761, 1.0e-4
563582
)
564-
assert inversion.reconstruction[13] == pytest.approx(0.047596735850, 1.0e-4)
583+
assert inversion.reconstruction[13] == pytest.approx(0.49999703908867, 1.0e-4)
565584

566585
assert inversion.mapped_reconstructed_data_dict[rectangular_mapper_7x7_3x3][
567586
4
568-
] == pytest.approx(0.0022574, 1.0e-4)
587+
] == pytest.approx(0.5000029, 1.0e-4)
569588
assert inversion.mapped_reconstructed_data_dict[delaunay_mapper_9_3x3][
570589
3
571-
] == pytest.approx(0.01545999, 1.0e-4)
572-
assert inversion.mapped_reconstructed_image[4] == pytest.approx(0.05237029, 1.0e-4)
590+
] == pytest.approx(0.49999704, 1.0e-4)
591+
assert inversion.mapped_reconstructed_image[4] == pytest.approx(0.99999998, 1.0e-4)
573592

574593

575594
def test__inversion_imaging__positive_only_solver(masked_imaging_7x7_no_blur):

test_autoarray/inversion/inversion/test_settings_dict.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,6 @@ def make_settings_dict():
1717
"use_positive_only_solver": False,
1818
"positive_only_uses_p_initial": False,
1919
"force_edge_pixels_to_zeros": True,
20-
"image_pixels_source_zero": None,
2120
"no_regularization_add_to_curvature_diag_value": 1e-08,
2221
"use_w_tilde_numpy": False,
2322
"use_source_loop": False,

0 commit comments

Comments
 (0)