1212from autoarray .inversion .mappers .abstract import Mapper
1313from autoarray .inversion .regularization .abstract import AbstractRegularization
1414from autoarray .settings import Settings
15- from autoarray .preloads import Preloads
1615from autoarray .structures .arrays .uniform_2d import Array2D
1716from autoarray .structures .grids .irregular_2d import Grid2DIrregular
1817from autoarray .structures .visibilities import Visibilities
@@ -27,7 +26,6 @@ def __init__(
2726 dataset : Union [Imaging , Interferometer , DatasetInterface ],
2827 linear_obj_list : List [LinearObj ],
2928 settings : Settings = None ,
30- preloads : Preloads = None ,
3129 xp = np ,
3230 ):
3331 """
@@ -74,8 +72,6 @@ def __init__(
7472
7573 self .settings = settings or Settings ()
7674
77- self .preloads = preloads or Preloads ()
78-
7975 self .use_jax = xp is not np
8076
8177 @property
@@ -234,9 +230,6 @@ def no_regularization_index_list(self) -> List[int]:
234230 @property
235231 def mapper_indices (self ) -> np .ndarray :
236232
237- if self .preloads .mapper_indices is not None :
238- return self .preloads .mapper_indices
239-
240233 mapper_indices = []
241234
242235 param_range_list = self .param_range_list_from (cls = Mapper )
@@ -386,6 +379,107 @@ def curvature_reg_matrix_reduced(self) -> Optional[np.ndarray]:
386379 # Zero rows and columns in the matrix we want to ignore
387380 return self .curvature_reg_matrix [ids_to_keep ][:, ids_to_keep ]
388381
382+ @cached_property
383+ def zeroed_ids_to_keep (self ):
384+ """
385+ Return the **positive global indices** of linear parameters that should be
386+ kept (solved for) in the inversion, accounting for **zeroed pixel indices**
387+ from one or more mappers.
388+
389+ ---------------------------------------------------------------------------
390+ Parameter vector layout
391+ ---------------------------------------------------------------------------
392+ This method assumes the full linear parameter vector is ordered as:
393+
394+ [ non-pixel linear objects ][ mapper_0 pixels ][ mapper_1 pixels ] ... [ mapper_M pixels ]
395+
396+ where:
397+
398+ - *Non-pixel linear objects* include quantities such as analytic light
399+ profiles, regularization amplitudes, etc.
400+ - Each mapper contributes a contiguous block of pixel-based linear parameters.
401+ - The concatenated pixel blocks occupy the **final** entries of the parameter
402+ vector, with total length:
403+
404+ total_pixels = sum(mapper.mesh.pixels for mapper in mappers)
405+
406+ ---------------------------------------------------------------------------
407+ Zeroed pixel convention
408+ ---------------------------------------------------------------------------
409+ For each mapper:
410+
411+ - `mapper.mesh.zeroed_pixels` must be a 1D array of **positive, mesh-local**
412+ pixel indices in the range `[0, mapper.mesh.pixels - 1]`.
413+ - These indices identify pixels that should be **excluded** from the linear
414+ solve (e.g. edge pixels, masked regions, or padding pixels).
415+ - Indexing is defined purely within the mapper’s own pixelization (e.g.
416+ row-major flattening for rectangular meshes).
417+
418+ This method converts all mesh-local zeroed pixel indices into **global
419+ parameter indices**, correctly offsetting for:
420+ - the presence of non-pixel linear objects at the start of the vector
421+ - the cumulative pixel counts of preceding mappers
422+
423+ ---------------------------------------------------------------------------
424+ Backend and implementation details
425+ ---------------------------------------------------------------------------
426+ - The implementation is backend-agnostic and supports both NumPy and JAX via
427+ `self._xp`.
428+ - The returned indices are **positive global indices**, suitable for advanced
429+ indexing of:
430+ - `self.data_vector`
431+ - `self.curvature_reg_matrix`
432+ - When using JAX, this method avoids backend-incompatible operations and
433+ preserves JIT compatibility under the same constraints as the rest of the
434+ inversion pipeline.
435+
436+ Returns
437+ -------
438+ array-like
439+ A 1D array of **positive global indices**, sorted in ascending order,
440+ corresponding to linear parameters that should be kept in the inversion.
441+ """
442+
443+ mapper_list = self .cls_list_from (cls = Mapper )
444+
445+ n_total = int (self .total_params )
446+
447+ pixels_per_mapper = [int (m .mesh .pixels ) for m in mapper_list ]
448+ total_pixels = int (sum (pixels_per_mapper ))
449+
450+ # Global start index of concatenated pixel block
451+ pixel_start = n_total - total_pixels
452+
453+ # Total number of zeroed pixels across all mappers (Python int => static)
454+ total_zeroed = int (sum (len (m .mesh .zeroed_pixels ) for m in mapper_list ))
455+ n_keep = int (n_total - total_zeroed )
456+
457+ # Build global indices-to-zero across all mappers
458+ zeros_global_list = []
459+ offset = 0
460+ for m , n_pix in zip (mapper_list , pixels_per_mapper ):
461+ zeros_local = self ._xp .asarray (m .mesh .zeroed_pixels , dtype = self ._xp .int32 )
462+ zeros_global_list .append (pixel_start + offset + zeros_local )
463+ offset += n_pix
464+
465+ zeros_global = (
466+ self ._xp .concatenate (zeros_global_list )
467+ if len (zeros_global_list ) > 0
468+ else self ._xp .asarray ([], dtype = self ._xp .int32 )
469+ )
470+
471+ keep = self ._xp .ones ((n_total ,), dtype = bool )
472+
473+ if self ._xp is np :
474+ keep [zeros_global ] = False
475+ keep_ids = self ._xp .nonzero (keep )[0 ]
476+
477+ else :
478+ keep = keep .at [zeros_global ].set (False )
479+ keep_ids = self ._xp .nonzero (keep , size = n_keep )[0 ]
480+
481+ return keep_ids
482+
389483 @cached_property
390484 def reconstruction (self ) -> np .ndarray :
391485 """
@@ -405,16 +499,13 @@ def reconstruction(self) -> np.ndarray:
405499
406500 if self .settings .use_positive_only_solver :
407501
408- if self .preloads .source_pixel_zeroed_indices is not None :
409-
410- # ids of values which are not zeroed and therefore kept in soluiton, which is computed in preloads.
411- ids_to_keep = self .preloads .source_pixel_zeroed_indices_to_keep
502+ if self .settings .use_edge_zeroed_pixels and self .has (cls = Mapper ):
412503
413504 # Use advanced indexing to select rows/columns
414- data_vector = self .data_vector [ids_to_keep ]
415- curvature_reg_matrix = self .curvature_reg_matrix [ids_to_keep ][
416- :, ids_to_keep
417- ]
505+ data_vector = self .data_vector [self . zeroed_ids_to_keep ]
506+ curvature_reg_matrix = self .curvature_reg_matrix [
507+ self . zeroed_ids_to_keep
508+ ][:, self . zeroed_ids_to_keep ]
418509
419510 # Perform reconstruction via fnnls
420511 reconstruction_partial = (
@@ -431,11 +522,11 @@ def reconstruction(self) -> np.ndarray:
431522
432523 # Scatter the partial solution back to the full shape
433524 if self ._xp .__name__ .startswith ("jax" ):
434- reconstruction = reconstruction .at [ids_to_keep ].set (
525+ reconstruction = reconstruction .at [self . zeroed_ids_to_keep ].set (
435526 reconstruction_partial
436527 )
437528 else :
438- reconstruction [ids_to_keep ] = reconstruction_partial
529+ reconstruction [self . zeroed_ids_to_keep ] = reconstruction_partial
439530
440531 return reconstruction
441532
0 commit comments