11import logging
22
3+ import jax .numpy as jnp
34import numpy as np
45
56logger = logging .getLogger (__name__ )
910
1011class Preloads :
1112
12- def __init__ (self , mapper_indices : np .ndarray = None ):
13+ def __init__ (self , mapper_indices : np .ndarray = None , source_pixel_zeroed_indices : np . ndarray = None ):
1314 """
14- Preload in memory arrays and matrices used to perform pixelized linear inversions, for both key functionality
15- and speeding up the run-time of the inversion .
15+ Stores preloaded arrays and matrices used during pixelized linear inversions, improving both performance
16+ and compatibility with JAX .
1617
17- Certain preloading arrays (e.g. `mapper_indices`) are stored here because JAX requires that they are
18- known and defined as static arrays before sampling. During each inversion, the preloads will be inspected
19- for these fixed arrays and used to change matrix shapes in an identical way for every likelihood evaluation.
20-
21- Other preloading arrays are used purely to speed up the run-time of the inversion, such as
22- the `curvature_matrix_preload` array. For certain models (e.g. if the source model is fixed and only the
23- lens light is being fitted for), certain quadrants of the `curvature_matrix` are fixed
24- for every likelihood evaluation, meaning that they can be preloaded and used to speed up the inversion.
18+ Some arrays (e.g. `mapper_indices`) are required to be defined before sampling begins, because JAX demands
19+ that input shapes remain static. These are used during each inversion to ensure consistent matrix shapes
20+ for all likelihood evaluations.
2521
22+ Other arrays (e.g. parts of the curvature matrix) are preloaded purely to improve performance. In cases where
23+ the source model is fixed (e.g. when fitting only the lens light), sections of the curvature matrix do not
24+ change and can be reused, avoiding redundant computation.
2625
2726 Parameters
2827 ----------
2928 mapper_indices
30- The integer indexes of the mapper pixels in a pixeized inversion, which separate their indexes from those
31- of linear light profiles in the inversion. This is used to extract `_reduced`
32- matrices (e.g. `curvature_matrix_reduced`) to compute the `log_evidence` terms of the pixelized inversion
33- likelihood function.
29+ The integer indices of mapper pixels in the inversion. Used to extract reduced matrices (e.g.
30+ `curvature_matrix_reduced`) that compute the pixelized inversion's log evidence term, where the indicies
31+ are requirred to separate the rows and columns of matrices from linear light profiles.
32+ source_pixel_zeroed_indices
33+ Indices of source pixels that should be set to zero in the reconstruction. These typically correspond to
34+ outer-edge source-plane regions with no image-plane mapping (e.g. outside a circular mask), helping
35+ separate the lens light from the pixelized source model.
3436 """
3537
36- self .mapper_indices = mapper_indices
38+ self .mapper_indices = None
39+ self .source_pixel_zeroed_indices = None
40+ self .source_pixel_zeroed_indices_to_keep = None
41+
42+ if mapper_indices is not None :
43+
44+ self .mapper_indices = jnp .array (mapper_indices )
45+
46+ if source_pixel_zeroed_indices is not None :
47+
48+ self .source_pixel_zeroed_indices = jnp .array (source_pixel_zeroed_indices )
49+
50+ ids_zeros = jnp .array (source_pixel_zeroed_indices , dtype = int )
51+
52+ values_to_solve = jnp .ones (np .max (mapper_indices ), dtype = bool )
53+ values_to_solve = values_to_solve .at [ids_zeros ].set (False )
54+
55+ # Get the indices where values_to_solve is True
56+ self .source_pixel_zeroed_indices_to_keep = jnp .where (values_to_solve )[0 ]
0 commit comments