Skip to content

Commit 7844633

Browse files
authored
Merge pull request #220 from Jammy2211/feature/remove_preloads
Feature/remove preloads
2 parents 4f88616 + 55a2246 commit 7844633

33 files changed

+407
-258
lines changed

autoarray/__init__.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,6 @@
6464
from .operators.contour import Grid2DContour
6565
from .layout.layout import Layout1D
6666
from .layout.layout import Layout2D
67-
from .preloads import Preloads
6867
from .structures.arrays.uniform_1d import Array1D
6968
from .structures.arrays.uniform_2d import Array2D
7069
from .structures.arrays.rgb import Array2DRGB

autoarray/config/general.yaml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ psf:
55
inversion:
66
check_reconstruction: true # If True, the inversion's reconstruction is checked to ensure the solution of a meshs's mapper is not an invalid solution where the values are all the same.
77
use_positive_only_solver: true # If True, inversion's use a positive-only linear algebra solver by default, which is slower but prevents unphysical negative values in the reconstructed solutuion.
8+
use_edge_zeroed_pixels : true # If True, the edge pixels of a pixelization are set to zero, which prevents unphysical values in the reconstructed solution at the edge of the pixelization.
89
no_regularization_add_to_curvature_diag_value : 1.0e-3 # The default value added to the curvature matrix's diagonal when regularization is not applied to a linear object, which prevents inversion's failing due to the matrix being singular.
910
use_border_relocator: false # If True, by default a pixelization's border is used to relocate all pixels outside its border to the border.
1011
reconstruction_vmax_factor: 0.5 # Plots of an Inversion's reconstruction use the reconstructed data's bright value multiplied by this factor.

autoarray/fixtures.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -410,7 +410,7 @@ def make_delaunay_mapper_9_3x3():
410410
pixel_scales=1.0,
411411
)
412412

413-
mesh = aa.mesh.Delaunay()
413+
mesh = aa.mesh.Delaunay(pixels=9)
414414

415415
interpolator = mesh.interpolator_from(
416416
source_plane_data_grid=make_grid_2d_sub_2_7x7(),
@@ -443,7 +443,7 @@ def make_knn_mapper_9_3x3():
443443
pixel_scales=1.0,
444444
)
445445

446-
mesh = aa.mesh.KNearestNeighbor(split_neighbor_division=1)
446+
mesh = aa.mesh.KNearestNeighbor(pixels=9, split_neighbor_division=1)
447447

448448
interpolator = mesh.interpolator_from(
449449
source_plane_data_grid=make_grid_2d_sub_2_7x7(),

autoarray/inversion/inversion/abstract.py

Lines changed: 108 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,6 @@
1212
from autoarray.inversion.mappers.abstract import Mapper
1313
from autoarray.inversion.regularization.abstract import AbstractRegularization
1414
from autoarray.settings import Settings
15-
from autoarray.preloads import Preloads
1615
from autoarray.structures.arrays.uniform_2d import Array2D
1716
from autoarray.structures.grids.irregular_2d import Grid2DIrregular
1817
from autoarray.structures.visibilities import Visibilities
@@ -27,7 +26,6 @@ def __init__(
2726
dataset: Union[Imaging, Interferometer, DatasetInterface],
2827
linear_obj_list: List[LinearObj],
2928
settings: Settings = None,
30-
preloads: Preloads = None,
3129
xp=np,
3230
):
3331
"""
@@ -74,8 +72,6 @@ def __init__(
7472

7573
self.settings = settings or Settings()
7674

77-
self.preloads = preloads or Preloads()
78-
7975
self.use_jax = xp is not np
8076

8177
@property
@@ -234,9 +230,6 @@ def no_regularization_index_list(self) -> List[int]:
234230
@property
235231
def mapper_indices(self) -> np.ndarray:
236232

237-
if self.preloads.mapper_indices is not None:
238-
return self.preloads.mapper_indices
239-
240233
mapper_indices = []
241234

242235
param_range_list = self.param_range_list_from(cls=Mapper)
@@ -386,6 +379,107 @@ def curvature_reg_matrix_reduced(self) -> Optional[np.ndarray]:
386379
# Zero rows and columns in the matrix we want to ignore
387380
return self.curvature_reg_matrix[ids_to_keep][:, ids_to_keep]
388381

382+
@cached_property
383+
def zeroed_ids_to_keep(self):
384+
"""
385+
Return the **positive global indices** of linear parameters that should be
386+
kept (solved for) in the inversion, accounting for **zeroed pixel indices**
387+
from one or more mappers.
388+
389+
---------------------------------------------------------------------------
390+
Parameter vector layout
391+
---------------------------------------------------------------------------
392+
This method assumes the full linear parameter vector is ordered as:
393+
394+
[ non-pixel linear objects ][ mapper_0 pixels ][ mapper_1 pixels ] ... [ mapper_M pixels ]
395+
396+
where:
397+
398+
- *Non-pixel linear objects* include quantities such as analytic light
399+
profiles, regularization amplitudes, etc.
400+
- Each mapper contributes a contiguous block of pixel-based linear parameters.
401+
- The concatenated pixel blocks occupy the **final** entries of the parameter
402+
vector, with total length:
403+
404+
total_pixels = sum(mapper.mesh.pixels for mapper in mappers)
405+
406+
---------------------------------------------------------------------------
407+
Zeroed pixel convention
408+
---------------------------------------------------------------------------
409+
For each mapper:
410+
411+
- `mapper.mesh.zeroed_pixels` must be a 1D array of **positive, mesh-local**
412+
pixel indices in the range `[0, mapper.mesh.pixels - 1]`.
413+
- These indices identify pixels that should be **excluded** from the linear
414+
solve (e.g. edge pixels, masked regions, or padding pixels).
415+
- Indexing is defined purely within the mapper’s own pixelization (e.g.
416+
row-major flattening for rectangular meshes).
417+
418+
This method converts all mesh-local zeroed pixel indices into **global
419+
parameter indices**, correctly offsetting for:
420+
- the presence of non-pixel linear objects at the start of the vector
421+
- the cumulative pixel counts of preceding mappers
422+
423+
---------------------------------------------------------------------------
424+
Backend and implementation details
425+
---------------------------------------------------------------------------
426+
- The implementation is backend-agnostic and supports both NumPy and JAX via
427+
`self._xp`.
428+
- The returned indices are **positive global indices**, suitable for advanced
429+
indexing of:
430+
- `self.data_vector`
431+
- `self.curvature_reg_matrix`
432+
- When using JAX, this method avoids backend-incompatible operations and
433+
preserves JIT compatibility under the same constraints as the rest of the
434+
inversion pipeline.
435+
436+
Returns
437+
-------
438+
array-like
439+
A 1D array of **positive global indices**, sorted in ascending order,
440+
corresponding to linear parameters that should be kept in the inversion.
441+
"""
442+
443+
mapper_list = self.cls_list_from(cls=Mapper)
444+
445+
n_total = int(self.total_params)
446+
447+
pixels_per_mapper = [int(m.mesh.pixels) for m in mapper_list]
448+
total_pixels = int(sum(pixels_per_mapper))
449+
450+
# Global start index of concatenated pixel block
451+
pixel_start = n_total - total_pixels
452+
453+
# Total number of zeroed pixels across all mappers (Python int => static)
454+
total_zeroed = int(sum(len(m.mesh.zeroed_pixels) for m in mapper_list))
455+
n_keep = int(n_total - total_zeroed)
456+
457+
# Build global indices-to-zero across all mappers
458+
zeros_global_list = []
459+
offset = 0
460+
for m, n_pix in zip(mapper_list, pixels_per_mapper):
461+
zeros_local = self._xp.asarray(m.mesh.zeroed_pixels, dtype=self._xp.int32)
462+
zeros_global_list.append(pixel_start + offset + zeros_local)
463+
offset += n_pix
464+
465+
zeros_global = (
466+
self._xp.concatenate(zeros_global_list)
467+
if len(zeros_global_list) > 0
468+
else self._xp.asarray([], dtype=self._xp.int32)
469+
)
470+
471+
keep = self._xp.ones((n_total,), dtype=bool)
472+
473+
if self._xp is np:
474+
keep[zeros_global] = False
475+
keep_ids = self._xp.nonzero(keep)[0]
476+
477+
else:
478+
keep = keep.at[zeros_global].set(False)
479+
keep_ids = self._xp.nonzero(keep, size=n_keep)[0]
480+
481+
return keep_ids
482+
389483
@cached_property
390484
def reconstruction(self) -> np.ndarray:
391485
"""
@@ -405,16 +499,13 @@ def reconstruction(self) -> np.ndarray:
405499

406500
if self.settings.use_positive_only_solver:
407501

408-
if self.preloads.source_pixel_zeroed_indices is not None:
409-
410-
# ids of values which are not zeroed and therefore kept in soluiton, which is computed in preloads.
411-
ids_to_keep = self.preloads.source_pixel_zeroed_indices_to_keep
502+
if self.settings.use_edge_zeroed_pixels and self.has(cls=Mapper):
412503

413504
# Use advanced indexing to select rows/columns
414-
data_vector = self.data_vector[ids_to_keep]
415-
curvature_reg_matrix = self.curvature_reg_matrix[ids_to_keep][
416-
:, ids_to_keep
417-
]
505+
data_vector = self.data_vector[self.zeroed_ids_to_keep]
506+
curvature_reg_matrix = self.curvature_reg_matrix[
507+
self.zeroed_ids_to_keep
508+
][:, self.zeroed_ids_to_keep]
418509

419510
# Perform reconstruction via fnnls
420511
reconstruction_partial = (
@@ -431,11 +522,11 @@ def reconstruction(self) -> np.ndarray:
431522

432523
# Scatter the partial solution back to the full shape
433524
if self._xp.__name__.startswith("jax"):
434-
reconstruction = reconstruction.at[ids_to_keep].set(
525+
reconstruction = reconstruction.at[self.zeroed_ids_to_keep].set(
435526
reconstruction_partial
436527
)
437528
else:
438-
reconstruction[ids_to_keep] = reconstruction_partial
529+
reconstruction[self.zeroed_ids_to_keep] = reconstruction_partial
439530

440531
return reconstruction
441532

autoarray/inversion/inversion/factory.py

Lines changed: 0 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -24,15 +24,13 @@
2424
InversionImagingSparse,
2525
)
2626
from autoarray.settings import Settings
27-
from autoarray.preloads import Preloads
2827
from autoarray.structures.arrays.uniform_2d import Array2D
2928

3029

3130
def inversion_from(
3231
dataset: Union[Imaging, Interferometer, DatasetInterface],
3332
linear_obj_list: List[LinearObj],
3433
settings: Settings = None,
35-
preloads: Preloads = None,
3634
xp=np,
3735
):
3836
"""
@@ -68,7 +66,6 @@ def inversion_from(
6866
dataset=dataset,
6967
linear_obj_list=linear_obj_list,
7068
settings=settings,
71-
preloads=preloads,
7269
xp=xp,
7370
)
7471

@@ -81,7 +78,6 @@ def inversion_imaging_from(
8178
dataset,
8279
linear_obj_list: List[LinearObj],
8380
settings: Settings = None,
84-
preloads: Preloads = None,
8581
xp=np,
8682
):
8783
"""
@@ -133,23 +129,20 @@ def inversion_imaging_from(
133129
dataset=dataset,
134130
linear_obj_list=linear_obj_list,
135131
settings=settings,
136-
preloads=preloads,
137132
xp=xp,
138133
)
139134

140135
return InversionImagingSparse(
141136
dataset=dataset,
142137
linear_obj_list=linear_obj_list,
143138
settings=settings,
144-
preloads=preloads,
145139
xp=xp,
146140
)
147141

148142
return InversionImagingMapping(
149143
dataset=dataset,
150144
linear_obj_list=linear_obj_list,
151145
settings=settings,
152-
preloads=preloads,
153146
xp=xp,
154147
)
155148

autoarray/inversion/inversion/imaging/abstract.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,6 @@
88
from autoarray.inversion.inversion.abstract import AbstractInversion
99
from autoarray.inversion.linear_obj.linear_obj import LinearObj
1010
from autoarray.settings import Settings
11-
from autoarray.preloads import Preloads
1211

1312
from autoarray.inversion.inversion.imaging import inversion_imaging_util
1413

@@ -19,7 +18,6 @@ def __init__(
1918
dataset: Union[Imaging, DatasetInterface],
2019
linear_obj_list: List[LinearObj],
2120
settings: Settings = None,
22-
preloads: Preloads = None,
2321
xp=np,
2422
):
2523
"""
@@ -67,7 +65,6 @@ def __init__(
6765
dataset=dataset,
6866
linear_obj_list=linear_obj_list,
6967
settings=settings,
70-
preloads=preloads,
7168
xp=xp,
7269
)
7370

autoarray/inversion/inversion/imaging/mapping.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,6 @@
99
from autoarray.inversion.linear_obj.linear_obj import LinearObj
1010
from autoarray.inversion.mappers.abstract import Mapper
1111
from autoarray.settings import Settings
12-
from autoarray.preloads import Preloads
1312
from autoarray.structures.arrays.uniform_2d import Array2D
1413

1514
from autoarray.inversion.inversion import inversion_util
@@ -22,7 +21,6 @@ def __init__(
2221
dataset: Union[Imaging, DatasetInterface],
2322
linear_obj_list: List[LinearObj],
2423
settings: Settings = None,
25-
preloads: Preloads = None,
2624
xp=np,
2725
):
2826
"""
@@ -49,7 +47,6 @@ def __init__(
4947
dataset=dataset,
5048
linear_obj_list=linear_obj_list,
5149
settings=settings,
52-
preloads=preloads,
5350
xp=xp,
5451
)
5552

autoarray/inversion/inversion/imaging/sparse.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,6 @@
1010
from autoarray.settings import Settings
1111
from autoarray.inversion.linear_obj.func_list import AbstractLinearObjFuncList
1212
from autoarray.inversion.mappers.abstract import Mapper
13-
from autoarray.preloads import Preloads
1413
from autoarray.structures.arrays.uniform_2d import Array2D
1514

1615
from autoarray.inversion.inversion.imaging import inversion_imaging_util
@@ -22,7 +21,6 @@ def __init__(
2221
dataset: Union[Imaging, DatasetInterface],
2322
linear_obj_list: List[LinearObj],
2423
settings: Settings = None,
25-
preloads: Preloads = None,
2624
xp=np,
2725
):
2826
"""
@@ -49,7 +47,6 @@ def __init__(
4947
dataset=dataset,
5048
linear_obj_list=linear_obj_list,
5149
settings=settings,
52-
preloads=preloads,
5350
xp=xp,
5451
)
5552

autoarray/inversion/inversion/imaging_numba/sparse.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,6 @@
1010
from autoarray.settings import Settings
1111
from autoarray.inversion.linear_obj.func_list import AbstractLinearObjFuncList
1212
from autoarray.inversion.mappers.abstract import Mapper
13-
from autoarray.preloads import Preloads
1413
from autoarray.structures.arrays.uniform_2d import Array2D
1514

1615
from autoarray.inversion.inversion.imaging_numba import inversion_imaging_numba_util
@@ -22,7 +21,6 @@ def __init__(
2221
dataset: Union[Imaging, DatasetInterface],
2322
linear_obj_list: List[LinearObj],
2423
settings: Settings = None,
25-
preloads: Preloads = None,
2624
xp=np,
2725
):
2826
"""
@@ -49,7 +47,6 @@ def __init__(
4947
dataset=dataset,
5048
linear_obj_list=linear_obj_list,
5149
settings=settings,
52-
preloads=preloads,
5350
xp=xp,
5451
)
5552

autoarray/inversion/mappers/abstract.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,6 @@ def __init__(
2323
regularization: Optional[AbstractRegularization] = None,
2424
settings: Settings = None,
2525
image_plane_mesh_grid=None,
26-
preloads=None,
2726
xp=np,
2827
):
2928
"""
@@ -96,7 +95,7 @@ def __init__(
9695
self.interpolator = interpolator
9796

9897
self.image_plane_mesh_grid = image_plane_mesh_grid
99-
self.preloads = preloads
98+
10099
self.settings = settings or Settings()
101100

102101
@property
@@ -111,6 +110,10 @@ def pixels(self) -> int:
111110
def mask(self):
112111
return self.source_plane_data_grid.mask
113112

113+
@property
114+
def mesh(self):
115+
return self.interpolator.mesh
116+
114117
@property
115118
def mesh_geometry(self):
116119
return self.interpolator.mesh_geometry

0 commit comments

Comments
 (0)