Skip to content

Commit 741ff6a

Browse files
authored
Merge pull request #183 from Jammy2211/feature/jax_inversion
Feature/jax inversion
2 parents 9a316ef + 67a5830 commit 741ff6a

File tree

31 files changed

+945
-647
lines changed

31 files changed

+945
-647
lines changed

autoarray/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,7 @@
6565
from .operators.contour import Grid2DContour
6666
from .layout.layout import Layout1D
6767
from .layout.layout import Layout2D
68+
from .preloads import Preloads
6869
from .structures.arrays.uniform_1d import Array1D
6970
from .structures.arrays.uniform_2d import Array2D
7071
from .structures.arrays.rgb import Array2DRGB

autoarray/inversion/inversion/abstract.py

Lines changed: 75 additions & 112 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
from autoarray.inversion.pixelization.mappers.abstract import AbstractMapper
1414
from autoarray.inversion.regularization.abstract import AbstractRegularization
1515
from autoarray.inversion.inversion.settings import SettingsInversion
16+
from autoarray.preloads import Preloads
1617
from autoarray.structures.arrays.uniform_2d import Array2D
1718
from autoarray.structures.visibilities import Visibilities
1819

@@ -27,6 +28,7 @@ def __init__(
2728
dataset: Union[Imaging, Interferometer, DatasetInterface],
2829
linear_obj_list: List[LinearObj],
2930
settings: SettingsInversion = SettingsInversion(),
31+
preloads: Preloads = None,
3032
):
3133
"""
3234
An `Inversion` reconstructs an input dataset using a list of linear objects (e.g. a list of analytic functions
@@ -66,23 +68,14 @@ def __init__(
6668
Settings controlling how an inversion is fitted for example which linear algebra formalism is used.
6769
"""
6870

69-
try:
70-
import numba
71-
except ModuleNotFoundError:
72-
raise exc.InversionException(
73-
"Inversion functionality (linear light profiles, pixelized reconstructions) is "
74-
"disabled if numba is not installed.\n\n"
75-
"This is because the run-times without numba are too slow.\n\n"
76-
"Please install numba, which is described at the following web page:\n\n"
77-
"https://pyautolens.readthedocs.io/en/latest/installation/overview.html"
78-
)
79-
8071
self.dataset = dataset
8172

8273
self.linear_obj_list = linear_obj_list
8374

8475
self.settings = settings
8576

77+
self.preloads = preloads or Preloads()
78+
8679
@property
8780
def data(self):
8881
return self.dataset.data
@@ -156,17 +149,9 @@ def param_range_list_from(self, cls: Type) -> List[List[int]]:
156149
-------
157150
A list of the index range of the parameters of each linear object in the inversion of the input cls type.
158151
"""
159-
index_list = []
160-
161-
pixel_count = 0
162-
163-
for linear_obj in self.linear_obj_list:
164-
if isinstance(linear_obj, cls):
165-
index_list.append([pixel_count, pixel_count + linear_obj.params])
166-
167-
pixel_count += linear_obj.params
168-
169-
return index_list
152+
return inversion_util.param_range_list_from(
153+
cls=cls, linear_obj_list=self.linear_obj_list
154+
)
170155

171156
def cls_list_from(self, cls: Type, cls_filtered: Optional[Type] = None) -> List:
172157
"""
@@ -267,6 +252,22 @@ def no_regularization_index_list(self) -> List[int]:
267252

268253
return no_regularization_index_list
269254

255+
@property
256+
def mapper_indices(self) -> np.ndarray:
257+
258+
if self.preloads.mapper_indices is not None:
259+
return self.preloads.mapper_indices
260+
261+
mapper_indices = []
262+
263+
param_range_list = self.param_range_list_from(cls=AbstractMapper)
264+
265+
for param_range in param_range_list:
266+
267+
mapper_indices += range(param_range[0], param_range[1])
268+
269+
return np.array(mapper_indices)
270+
270271
@property
271272
def mask(self) -> Array2D:
272273
return self.data.mask
@@ -354,19 +355,14 @@ def regularization_matrix_reduced(self) -> Optional[np.ndarray]:
354355
regularization it is bypassed.
355356
"""
356357

357-
regularization_matrix = self.regularization_matrix
358-
359358
if self.all_linear_obj_have_regularization:
360-
return regularization_matrix
359+
return self.regularization_matrix
361360

362-
regularization_matrix = np.delete(
363-
regularization_matrix, self.no_regularization_index_list, 0
364-
)
365-
regularization_matrix = np.delete(
366-
regularization_matrix, self.no_regularization_index_list, 1
367-
)
361+
# ids of values which are on edge so zero-d and not solved for.
362+
ids_to_keep = self.mapper_indices
368363

369-
return regularization_matrix
364+
# Zero rows and columns in the matrix we want to ignore
365+
return self.regularization_matrix[ids_to_keep][:, ids_to_keep]
370366

371367
@cached_property
372368
def curvature_reg_matrix(self) -> np.ndarray:
@@ -381,55 +377,31 @@ def curvature_reg_matrix(self) -> np.ndarray:
381377
if not self.has(cls=AbstractRegularization):
382378
return self.curvature_matrix
383379

384-
if len(self.regularization_list) == 1:
385-
curvature_matrix = self.curvature_matrix
386-
curvature_matrix += self.regularization_matrix
387-
388-
del self.__dict__["curvature_matrix"]
389-
390-
return curvature_matrix
391-
392-
return np.add(self.curvature_matrix, self.regularization_matrix)
380+
return jnp.add(self.curvature_matrix, self.regularization_matrix)
393381

394382
@cached_property
395-
def curvature_reg_matrix_reduced(self) -> np.ndarray:
383+
def curvature_reg_matrix_reduced(self) -> Optional[np.ndarray]:
396384
"""
397-
The linear system of equations solves for F + regularization_coefficient*H, which is computed below.
385+
The regularization matrix H is used to impose smoothness on our inversion's reconstruction. This enters the
386+
linear algebra system we solve for using D and F above and is given by
387+
equation (12) in https://arxiv.org/pdf/astro-ph/0302587.pdf.
398388
399-
This is the curvature reg matrix for only the mappers, which is necessary for computing the log det
400-
term without the linear light profiles included.
389+
A complete description of regularization is given in the `regularization.py` and `regularization_util.py`
390+
modules.
391+
392+
For multiple mappers, the regularization matrix is computed as the block diagonal of each individual mapper.
393+
The scipy function `block_diag` has an overhead associated with it and if there is only one mapper and
394+
regularization it is bypassed.
401395
"""
396+
402397
if self.all_linear_obj_have_regularization:
403398
return self.curvature_reg_matrix
404399

405-
curvature_reg_matrix = self.curvature_reg_matrix
400+
# ids of values which are on edge so zero-d and not solved for.
401+
ids_to_keep = self.mapper_indices
406402

407-
curvature_reg_matrix = np.delete(
408-
curvature_reg_matrix, self.no_regularization_index_list, 0
409-
)
410-
curvature_reg_matrix = np.delete(
411-
curvature_reg_matrix, self.no_regularization_index_list, 1
412-
)
413-
414-
return curvature_reg_matrix
415-
416-
@property
417-
def mapper_zero_pixel_list(self) -> np.ndarray:
418-
mapper_zero_pixel_list = []
419-
param_range_list = self.param_range_list_from(cls=LinearObj)
420-
for param_range, linear_obj in zip(param_range_list, self.linear_obj_list):
421-
if isinstance(linear_obj, AbstractMapper):
422-
mapping_matrix_for_image_pixels_source_zero = linear_obj.mapping_matrix[
423-
self.settings.image_pixels_source_zero
424-
]
425-
source_pixels_zero = (
426-
np.sum(mapping_matrix_for_image_pixels_source_zero != 0, axis=0)
427-
!= 0
428-
)
429-
mapper_zero_pixel_list.append(
430-
np.where(source_pixels_zero == True)[0] + param_range[0]
431-
)
432-
return mapper_zero_pixel_list
403+
# Zero rows and columns in the matrix we want to ignore
404+
return self.curvature_reg_matrix[ids_to_keep][:, ids_to_keep]
433405

434406
@cached_property
435407
def reconstruction(self) -> np.ndarray:
@@ -448,51 +420,36 @@ def reconstruction(self) -> np.ndarray:
448420
ZTx := np.dot(Z.T, x)
449421
"""
450422
if self.settings.use_positive_only_solver:
451-
"""
452-
For the new implementation, we now need to take out the cols and rows of
453-
the curvature_reg_matrix that corresponds to the parameters we force to be 0.
454-
Similar for the data vector.
455423

456-
What we actually doing is that we have set the correspoding cols of the Z to be 0.
457-
As the curvature_reg_matrix = ZTZ, so the cols and rows are all taken out.
458-
And the data_vector = ZTx, so the corresponding row is also taken out.
459-
"""
424+
if self.preloads.source_pixel_zeroed_indices is not None:
460425

461-
if (
462-
self.has(cls=AbstractMapper)
463-
and self.settings.force_edge_pixels_to_zeros
464-
):
426+
# ids of values which are not zeroed and therefore kept in soluiton, which is computed in preloads.
427+
ids_to_keep = self.preloads.source_pixel_zeroed_indices_to_keep
465428

466-
ids_zeros = jnp.array(self.mapper_edge_pixel_list, dtype=int)
429+
# Use advanced indexing to select rows/columns
430+
data_vector = self.data_vector[ids_to_keep]
431+
curvature_reg_matrix = self.curvature_reg_matrix[ids_to_keep][
432+
:, ids_to_keep
433+
]
467434

468-
values_to_solve = jnp.ones(
469-
self.curvature_reg_matrix.shape[0], dtype=bool
435+
# Perform reconstruction via fnnls
436+
reconstruction_partial = (
437+
inversion_util.reconstruction_positive_only_from(
438+
data_vector=data_vector,
439+
curvature_reg_matrix=curvature_reg_matrix,
440+
settings=self.settings,
441+
)
470442
)
471-
values_to_solve = values_to_solve.at[ids_zeros].set(False)
472-
473-
data_vector_input = self.data_vector[values_to_solve]
474443

475-
curvature_reg_matrix_input = self.curvature_reg_matrix[
476-
values_to_solve, :
477-
][:, values_to_solve]
444+
# Allocate full solution array
445+
reconstruction = jnp.zeros(self.data_vector.shape[0])
478446

479-
# Get the values to assign (must be a JAX array)
480-
reconstruction = inversion_util.reconstruction_positive_only_from(
481-
data_vector=data_vector_input,
482-
curvature_reg_matrix=curvature_reg_matrix_input,
483-
settings=self.settings,
447+
# Scatter the partial solution back to the full shape
448+
reconstruction = reconstruction.at[ids_to_keep].set(
449+
reconstruction_partial
484450
)
485451

486-
# Allocate JAX array
487-
solutions = jnp.zeros(self.curvature_reg_matrix.shape[0])
488-
489-
# Get indices where True
490-
indices = jnp.where(values_to_solve)[0]
491-
492-
# Set reconstruction values at those indices
493-
solutions = solutions.at[indices].set(reconstruction)
494-
495-
return solutions
452+
return reconstruction
496453

497454
else:
498455

@@ -522,7 +479,11 @@ def reconstruction_reduced(self) -> np.ndarray:
522479
if self.all_linear_obj_have_regularization:
523480
return self.reconstruction
524481

525-
return np.delete(self.reconstruction, self.no_regularization_index_list, axis=0)
482+
# ids of values which are on edge so zero-d and not solved for.
483+
ids_to_keep = self.mapper_indices
484+
485+
# Zero rows and columns in the matrix we want to ignore
486+
return self.reconstruction[ids_to_keep]
526487

527488
@property
528489
def reconstruction_dict(self) -> Dict[LinearObj, np.ndarray]:
@@ -665,9 +626,9 @@ def regularization_term(self) -> float:
665626
if not self.has(cls=AbstractRegularization):
666627
return 0.0
667628

668-
return np.matmul(
629+
return jnp.matmul(
669630
self.reconstruction_reduced.T,
670-
np.matmul(self.regularization_matrix_reduced, self.reconstruction_reduced),
631+
jnp.matmul(self.regularization_matrix_reduced, self.reconstruction_reduced),
671632
)
672633

673634
@cached_property
@@ -682,7 +643,9 @@ def log_det_curvature_reg_matrix_term(self) -> float:
682643

683644
try:
684645
return 2.0 * np.sum(
685-
np.log(np.diag(np.linalg.cholesky(self.curvature_reg_matrix_reduced)))
646+
jnp.log(
647+
jnp.diag(jnp.linalg.cholesky(self.curvature_reg_matrix_reduced))
648+
)
686649
)
687650
except np.linalg.LinAlgError as e:
688651
raise exc.InversionException() from e

autoarray/inversion/inversion/factory.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,13 +14,15 @@
1414
from autoarray.inversion.linear_obj.func_list import AbstractLinearObjFuncList
1515
from autoarray.inversion.inversion.imaging.w_tilde import InversionImagingWTilde
1616
from autoarray.inversion.inversion.settings import SettingsInversion
17+
from autoarray.preloads import Preloads
1718
from autoarray.structures.arrays.uniform_2d import Array2D
1819

1920

2021
def inversion_from(
2122
dataset: Union[Imaging, Interferometer, DatasetInterface],
2223
linear_obj_list: List[LinearObj],
2324
settings: SettingsInversion = SettingsInversion(),
25+
preloads: Preloads = None,
2426
):
2527
"""
2628
Factory which given an input dataset and list of linear objects, creates an `Inversion`.
@@ -55,6 +57,7 @@ def inversion_from(
5557
dataset=dataset,
5658
linear_obj_list=linear_obj_list,
5759
settings=settings,
60+
preloads=preloads,
5861
)
5962

6063
return inversion_interferometer_from(
@@ -68,6 +71,7 @@ def inversion_imaging_from(
6871
dataset,
6972
linear_obj_list: List[LinearObj],
7073
settings: SettingsInversion = SettingsInversion(),
74+
preloads: Preloads = None,
7175
):
7276
"""
7377
Factory which given an input `Imaging` dataset and list of linear objects, creates an `InversionImaging`.
@@ -126,6 +130,7 @@ def inversion_imaging_from(
126130
dataset=dataset,
127131
linear_obj_list=linear_obj_list,
128132
settings=settings,
133+
preloads=preloads,
129134
)
130135

131136

autoarray/inversion/inversion/imaging/abstract.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
import numpy as np
2-
from typing import Dict, List, Optional, Union, Type
2+
from typing import Dict, List, Union, Type
33

44
from autoconf import cached_property
55

@@ -10,6 +10,7 @@
1010
from autoarray.inversion.inversion.abstract import AbstractInversion
1111
from autoarray.inversion.linear_obj.linear_obj import LinearObj
1212
from autoarray.inversion.inversion.settings import SettingsInversion
13+
from autoarray.preloads import Preloads
1314

1415
from autoarray.inversion.inversion.imaging import inversion_imaging_util
1516

@@ -20,6 +21,7 @@ def __init__(
2021
dataset: Union[Imaging, DatasetInterface],
2122
linear_obj_list: List[LinearObj],
2223
settings: SettingsInversion = SettingsInversion(),
24+
preloads: Preloads = None,
2325
):
2426
"""
2527
An `Inversion` reconstructs an input dataset using a list of linear objects (e.g. a list of analytic functions
@@ -66,6 +68,7 @@ def __init__(
6668
dataset=dataset,
6769
linear_obj_list=linear_obj_list,
6870
settings=settings,
71+
preloads=preloads,
6972
)
7073

7174
@property

autoarray/inversion/inversion/imaging/mapping.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
from autoarray.inversion.linear_obj.linear_obj import LinearObj
1010
from autoarray.inversion.pixelization.mappers.abstract import AbstractMapper
1111
from autoarray.inversion.inversion.settings import SettingsInversion
12+
from autoarray.preloads import Preloads
1213
from autoarray.structures.arrays.uniform_2d import Array2D
1314

1415
from autoarray.inversion.inversion import inversion_util
@@ -21,6 +22,7 @@ def __init__(
2122
dataset: Union[Imaging, DatasetInterface],
2223
linear_obj_list: List[LinearObj],
2324
settings: SettingsInversion = SettingsInversion(),
25+
preloads: Preloads = None,
2426
):
2527
"""
2628
Constructs linear equations (via vectors and matrices) which allow for sets of simultaneous linear equations
@@ -46,6 +48,7 @@ def __init__(
4648
dataset=dataset,
4749
linear_obj_list=linear_obj_list,
4850
settings=settings,
51+
preloads=preloads,
4952
)
5053

5154
@property

0 commit comments

Comments
 (0)