1313from autoarray .inversion .pixelization .mappers .abstract import AbstractMapper
1414from autoarray .inversion .regularization .abstract import AbstractRegularization
1515from autoarray .inversion .inversion .settings import SettingsInversion
16+ from autoarray .preloads import Preloads
1617from autoarray .structures .arrays .uniform_2d import Array2D
1718from autoarray .structures .visibilities import Visibilities
1819
@@ -27,6 +28,7 @@ def __init__(
2728 dataset : Union [Imaging , Interferometer , DatasetInterface ],
2829 linear_obj_list : List [LinearObj ],
2930 settings : SettingsInversion = SettingsInversion (),
31+ preloads : Preloads = None ,
3032 ):
3133 """
3234 An `Inversion` reconstructs an input dataset using a list of linear objects (e.g. a list of analytic functions
@@ -66,23 +68,14 @@ def __init__(
6668 Settings controlling how an inversion is fitted for example which linear algebra formalism is used.
6769 """
6870
69- try :
70- import numba
71- except ModuleNotFoundError :
72- raise exc .InversionException (
73- "Inversion functionality (linear light profiles, pixelized reconstructions) is "
74- "disabled if numba is not installed.\n \n "
75- "This is because the run-times without numba are too slow.\n \n "
76- "Please install numba, which is described at the following web page:\n \n "
77- "https://pyautolens.readthedocs.io/en/latest/installation/overview.html"
78- )
79-
8071 self .dataset = dataset
8172
8273 self .linear_obj_list = linear_obj_list
8374
8475 self .settings = settings
8576
77+ self .preloads = preloads or Preloads ()
78+
8679 @property
8780 def data (self ):
8881 return self .dataset .data
@@ -156,17 +149,9 @@ def param_range_list_from(self, cls: Type) -> List[List[int]]:
156149 -------
157150 A list of the index range of the parameters of each linear object in the inversion of the input cls type.
158151 """
159- index_list = []
160-
161- pixel_count = 0
162-
163- for linear_obj in self .linear_obj_list :
164- if isinstance (linear_obj , cls ):
165- index_list .append ([pixel_count , pixel_count + linear_obj .params ])
166-
167- pixel_count += linear_obj .params
168-
169- return index_list
152+ return inversion_util .param_range_list_from (
153+ cls = cls , linear_obj_list = self .linear_obj_list
154+ )
170155
171156 def cls_list_from (self , cls : Type , cls_filtered : Optional [Type ] = None ) -> List :
172157 """
@@ -267,6 +252,22 @@ def no_regularization_index_list(self) -> List[int]:
267252
268253 return no_regularization_index_list
269254
255+ @property
256+ def mapper_indices (self ) -> np .ndarray :
257+
258+ if self .preloads .mapper_indices is not None :
259+ return self .preloads .mapper_indices
260+
261+ mapper_indices = []
262+
263+ param_range_list = self .param_range_list_from (cls = AbstractMapper )
264+
265+ for param_range in param_range_list :
266+
267+ mapper_indices += range (param_range [0 ], param_range [1 ])
268+
269+ return np .array (mapper_indices )
270+
270271 @property
271272 def mask (self ) -> Array2D :
272273 return self .data .mask
@@ -354,19 +355,14 @@ def regularization_matrix_reduced(self) -> Optional[np.ndarray]:
354355 regularization it is bypassed.
355356 """
356357
357- regularization_matrix = self .regularization_matrix
358-
359358 if self .all_linear_obj_have_regularization :
360- return regularization_matrix
359+ return self . regularization_matrix
361360
362- regularization_matrix = np .delete (
363- regularization_matrix , self .no_regularization_index_list , 0
364- )
365- regularization_matrix = np .delete (
366- regularization_matrix , self .no_regularization_index_list , 1
367- )
361+ # ids of values which are on edge so zero-d and not solved for.
362+ ids_to_keep = self .mapper_indices
368363
369- return regularization_matrix
364+ # Zero rows and columns in the matrix we want to ignore
365+ return self .regularization_matrix [ids_to_keep ][:, ids_to_keep ]
370366
371367 @cached_property
372368 def curvature_reg_matrix (self ) -> np .ndarray :
@@ -381,55 +377,31 @@ def curvature_reg_matrix(self) -> np.ndarray:
381377 if not self .has (cls = AbstractRegularization ):
382378 return self .curvature_matrix
383379
384- if len (self .regularization_list ) == 1 :
385- curvature_matrix = self .curvature_matrix
386- curvature_matrix += self .regularization_matrix
387-
388- del self .__dict__ ["curvature_matrix" ]
389-
390- return curvature_matrix
391-
392- return np .add (self .curvature_matrix , self .regularization_matrix )
380+ return jnp .add (self .curvature_matrix , self .regularization_matrix )
393381
394382 @cached_property
395- def curvature_reg_matrix_reduced (self ) -> np .ndarray :
383+ def curvature_reg_matrix_reduced (self ) -> Optional [ np .ndarray ] :
396384 """
397- The linear system of equations solves for F + regularization_coefficient*H, which is computed below.
385+ The regularization matrix H is used to impose smoothness on our inversion's reconstruction. This enters the
386+ linear algebra system we solve for using D and F above and is given by
387+ equation (12) in https://arxiv.org/pdf/astro-ph/0302587.pdf.
398388
399- This is the curvature reg matrix for only the mappers, which is necessary for computing the log det
400- term without the linear light profiles included.
389+ A complete description of regularization is given in the `regularization.py` and `regularization_util.py`
390+ modules.
391+
392+ For multiple mappers, the regularization matrix is computed as the block diagonal of each individual mapper.
393+ The scipy function `block_diag` has an overhead associated with it and if there is only one mapper and
394+ regularization it is bypassed.
401395 """
396+
402397 if self .all_linear_obj_have_regularization :
403398 return self .curvature_reg_matrix
404399
405- curvature_reg_matrix = self .curvature_reg_matrix
400+ # ids of values which are on edge so zero-d and not solved for.
401+ ids_to_keep = self .mapper_indices
406402
407- curvature_reg_matrix = np .delete (
408- curvature_reg_matrix , self .no_regularization_index_list , 0
409- )
410- curvature_reg_matrix = np .delete (
411- curvature_reg_matrix , self .no_regularization_index_list , 1
412- )
413-
414- return curvature_reg_matrix
415-
416- @property
417- def mapper_zero_pixel_list (self ) -> np .ndarray :
418- mapper_zero_pixel_list = []
419- param_range_list = self .param_range_list_from (cls = LinearObj )
420- for param_range , linear_obj in zip (param_range_list , self .linear_obj_list ):
421- if isinstance (linear_obj , AbstractMapper ):
422- mapping_matrix_for_image_pixels_source_zero = linear_obj .mapping_matrix [
423- self .settings .image_pixels_source_zero
424- ]
425- source_pixels_zero = (
426- np .sum (mapping_matrix_for_image_pixels_source_zero != 0 , axis = 0 )
427- != 0
428- )
429- mapper_zero_pixel_list .append (
430- np .where (source_pixels_zero == True )[0 ] + param_range [0 ]
431- )
432- return mapper_zero_pixel_list
403+ # Zero rows and columns in the matrix we want to ignore
404+ return self .curvature_reg_matrix [ids_to_keep ][:, ids_to_keep ]
433405
434406 @cached_property
435407 def reconstruction (self ) -> np .ndarray :
@@ -448,51 +420,36 @@ def reconstruction(self) -> np.ndarray:
448420 ZTx := np.dot(Z.T, x)
449421 """
450422 if self .settings .use_positive_only_solver :
451- """
452- For the new implementation, we now need to take out the cols and rows of
453- the curvature_reg_matrix that corresponds to the parameters we force to be 0.
454- Similar for the data vector.
455423
456- What we actually doing is that we have set the correspoding cols of the Z to be 0.
457- As the curvature_reg_matrix = ZTZ, so the cols and rows are all taken out.
458- And the data_vector = ZTx, so the corresponding row is also taken out.
459- """
424+ if self .preloads .source_pixel_zeroed_indices is not None :
460425
461- if (
462- self .has (cls = AbstractMapper )
463- and self .settings .force_edge_pixels_to_zeros
464- ):
426+ # ids of values which are not zeroed and therefore kept in soluiton, which is computed in preloads.
427+ ids_to_keep = self .preloads .source_pixel_zeroed_indices_to_keep
465428
466- ids_zeros = jnp .array (self .mapper_edge_pixel_list , dtype = int )
429+ # Use advanced indexing to select rows/columns
430+ data_vector = self .data_vector [ids_to_keep ]
431+ curvature_reg_matrix = self .curvature_reg_matrix [ids_to_keep ][
432+ :, ids_to_keep
433+ ]
467434
468- values_to_solve = jnp .ones (
469- self .curvature_reg_matrix .shape [0 ], dtype = bool
435+ # Perform reconstruction via fnnls
436+ reconstruction_partial = (
437+ inversion_util .reconstruction_positive_only_from (
438+ data_vector = data_vector ,
439+ curvature_reg_matrix = curvature_reg_matrix ,
440+ settings = self .settings ,
441+ )
470442 )
471- values_to_solve = values_to_solve .at [ids_zeros ].set (False )
472-
473- data_vector_input = self .data_vector [values_to_solve ]
474443
475- curvature_reg_matrix_input = self .curvature_reg_matrix [
476- values_to_solve , :
477- ][:, values_to_solve ]
444+ # Allocate full solution array
445+ reconstruction = jnp .zeros (self .data_vector .shape [0 ])
478446
479- # Get the values to assign (must be a JAX array)
480- reconstruction = inversion_util .reconstruction_positive_only_from (
481- data_vector = data_vector_input ,
482- curvature_reg_matrix = curvature_reg_matrix_input ,
483- settings = self .settings ,
447+ # Scatter the partial solution back to the full shape
448+ reconstruction = reconstruction .at [ids_to_keep ].set (
449+ reconstruction_partial
484450 )
485451
486- # Allocate JAX array
487- solutions = jnp .zeros (self .curvature_reg_matrix .shape [0 ])
488-
489- # Get indices where True
490- indices = jnp .where (values_to_solve )[0 ]
491-
492- # Set reconstruction values at those indices
493- solutions = solutions .at [indices ].set (reconstruction )
494-
495- return solutions
452+ return reconstruction
496453
497454 else :
498455
@@ -522,7 +479,11 @@ def reconstruction_reduced(self) -> np.ndarray:
522479 if self .all_linear_obj_have_regularization :
523480 return self .reconstruction
524481
525- return np .delete (self .reconstruction , self .no_regularization_index_list , axis = 0 )
482+ # ids of values which are on edge so zero-d and not solved for.
483+ ids_to_keep = self .mapper_indices
484+
485+ # Zero rows and columns in the matrix we want to ignore
486+ return self .reconstruction [ids_to_keep ]
526487
527488 @property
528489 def reconstruction_dict (self ) -> Dict [LinearObj , np .ndarray ]:
@@ -665,9 +626,9 @@ def regularization_term(self) -> float:
665626 if not self .has (cls = AbstractRegularization ):
666627 return 0.0
667628
668- return np .matmul (
629+ return jnp .matmul (
669630 self .reconstruction_reduced .T ,
670- np .matmul (self .regularization_matrix_reduced , self .reconstruction_reduced ),
631+ jnp .matmul (self .regularization_matrix_reduced , self .reconstruction_reduced ),
671632 )
672633
673634 @cached_property
@@ -682,7 +643,9 @@ def log_det_curvature_reg_matrix_term(self) -> float:
682643
683644 try :
684645 return 2.0 * np .sum (
685- np .log (np .diag (np .linalg .cholesky (self .curvature_reg_matrix_reduced )))
646+ jnp .log (
647+ jnp .diag (jnp .linalg .cholesky (self .curvature_reg_matrix_reduced ))
648+ )
686649 )
687650 except np .linalg .LinAlgError as e :
688651 raise exc .InversionException () from e
0 commit comments