Skip to content

Commit d30cb17

Browse files
Jammy2211Jammy2211
authored andcommitted
Fix numba code
1 parent b8efd47 commit d30cb17

File tree

5 files changed

+81
-67
lines changed

5 files changed

+81
-67
lines changed

autoarray/inversion/inversion/imaging_numba/inversion_imaging_numba_util.py

Lines changed: 59 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -651,57 +651,85 @@ def curvature_matrix_off_diags_via_sparse_operator_from(
651651

652652

653653
@numba_util.jit()
654-
def curvature_matrix_off_diags_via_data_linear_func_matrix_from(
655-
data_linear_func_matrix: np.ndarray,
654+
def curvature_matrix_off_diags_via_mapper_and_linear_func_curvature_vector_from(
656655
data_to_pix_unique: np.ndarray,
657656
data_weights: np.ndarray,
658657
pix_lengths: np.ndarray,
659658
pix_pixels: int,
660-
):
659+
curvature_weights: np.ndarray, # shape (n_unmasked, n_funcs)
660+
mask: np.ndarray, # shape (ny, nx), bool
661+
psf_kernel: np.ndarray, # shape (ky, kx)
662+
) -> np.ndarray:
661663
"""
662-
Returns the off diagonal terms in the curvature matrix `F` (see Warren & Dye 2003) between a mapper object
663-
and a linear func object, using the preloaded `data_linear_func_matrix` of the values of the linear functions.
664-
664+
Returns the off-diagonal terms in the curvature matrix `F` (see Warren & Dye 2003)
665+
between a mapper object and a linear func object, using the unique mappings between
666+
data pixels and pixelization pixels.
665667
666-
If a linear function in an inversion is fixed, its values can be evaluated and preloaded beforehand. For every
667-
data pixel, the PSF convolution with this preloaded linear function can also be preloaded, in a matrix of
668-
shape [data_pixels, 1].
668+
This version applies the PSF directly as a 2D convolution kernel. The curvature
669+
weights of the linear function object (values of the linear function divided by the
670+
noise-map squared) are expanded into the native 2D image grid, convolved with the PSF
671+
kernel, and then remapped back to the 1D slim representation.
669672
670-
When mapper objects and linear functions are used simultaneously in an inversion, this preloaded matrix
671-
significantly speed up the computation of their off-diagonal terms in the curvature matrix.
672-
673-
This function performs this efficient calcluation via the preloaded `data_linear_func_matrix`.
673+
For each unique mapping between a data pixel and a pixelization pixel, the convolved
674+
curvature weights at that data pixel are multiplied by the mapping weights and
675+
accumulated into the off-diagonal block of the curvature matrix. This accounts for
676+
sub-pixel mappings between data pixels and pixelization pixels.
674677
675678
Parameters
676679
----------
677-
data_linear_func_matrix
678-
A matrix of shape [data_pixels, total_fixed_linear_functions] that for each data pixel, maps it to the sum of
679-
the values of a linear object function convolved with the PSF kernel at the data pixel.
680680
data_to_pix_unique
681-
The indexes of all pixels that each data pixel maps to (see the `Mapper` object).
681+
An array that maps every data pixel index (e.g. the masked image pixel indexes in 1D)
682+
to its unique set of pixelization pixel indexes (see `data_slim_to_pixelization_unique_from`).
682683
data_weights
683-
The weights of all pixels that each data pixel maps to (see the `Mapper` object).
684+
For every unique mapping between a set of data sub-pixels and a pixelization pixel,
685+
the weight of this mapping based on the number of sub-pixels that map to the pixelization pixel.
684686
pix_lengths
685-
The number of pixelization pixels that each data pixel maps to (see the `Mapper` object).
687+
A 1D array describing how many unique pixels each data pixel maps to. Used to iterate over
688+
`data_to_pix_unique` and `data_weights`.
686689
pix_pixels
687-
The number of pixelization pixels in the pixelization (see the `Mapper` object).
688-
"""
689-
690-
linear_func_pixels = data_linear_func_matrix.shape[1]
691-
692-
off_diag = np.zeros((pix_pixels, linear_func_pixels))
690+
The total number of pixels in the pixelization that reconstructs the data.
691+
curvature_weights
692+
The operated values of the linear function divided by the noise-map squared, with shape
693+
[n_unmasked_data_pixels, n_linear_func_pixels].
694+
mask
695+
A 2D boolean mask of shape (ny, nx) indicating which pixels are in the data region.
696+
psf_kernel
697+
The PSF kernel in its native 2D form, centered (odd dimensions recommended).
693698
699+
Returns
700+
-------
701+
ndarray
702+
The off-diagonal block of the curvature matrix `F` (see Warren & Dye 2003),
703+
with shape [pix_pixels, n_linear_func_pixels].
704+
"""
694705
data_pixels = data_weights.shape[0]
695-
706+
n_funcs = curvature_weights.shape[1]
707+
ny, nx = mask.shape
708+
709+
# Expand curvature weights into native grid
710+
curvature_native = np.zeros((ny, nx, n_funcs))
711+
unmasked_coords = np.argwhere(~mask)
712+
for i, (y, x) in enumerate(unmasked_coords):
713+
for f in range(n_funcs):
714+
curvature_native[y, x, f] = curvature_weights[i, f]
715+
716+
# Convolve in native space
717+
blurred_native = convolve_with_kernel_native(curvature_native, psf_kernel)
718+
719+
# Map back to slim representation
720+
blurred_slim = np.zeros((data_pixels, n_funcs))
721+
for i, (y, x) in enumerate(unmasked_coords):
722+
for f in range(n_funcs):
723+
blurred_slim[i, f] = blurred_native[y, x, f]
724+
725+
# Accumulate into off_diag
726+
off_diag = np.zeros((pix_pixels, n_funcs))
696727
for data_0 in range(data_pixels):
697728
for pix_0_index in range(pix_lengths[data_0]):
698729
data_0_weight = data_weights[data_0, pix_0_index]
699730
pix_0 = data_to_pix_unique[data_0, pix_0_index]
700-
701-
for linear_index in range(linear_func_pixels):
702-
off_diag[pix_0, linear_index] += (
703-
data_linear_func_matrix[data_0, linear_index] * data_0_weight
704-
)
731+
for f in range(n_funcs):
732+
off_diag[pix_0, f] += data_0_weight * blurred_slim[data_0, f]
705733

706734
return off_diag
707735

autoarray/inversion/inversion/imaging_numba/sparse.py

Lines changed: 14 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -276,18 +276,16 @@ def _curvature_matrix_mapper_diag(self) -> Optional[np.ndarray]:
276276
mapper_i = mapper_list[i]
277277
mapper_param_range_i = mapper_param_range_list[i]
278278

279-
diag = (
280-
inversion_imaging_numba_util.curvature_matrix_via_sparse_operator_from(
281-
psf_precision_operator=self.sparse_operator.psf_precision_operator_sparse,
282-
psf_precision_indexes=self.sparse_operator.indexes,
283-
psf_precision_lengths=self.sparse_operator.lengths,
284-
data_to_pix_unique=np.array(
285-
mapper_i.unique_mappings.data_to_pix_unique
286-
),
287-
data_weights=np.array(mapper_i.unique_mappings.data_weights),
288-
pix_lengths=np.array(mapper_i.unique_mappings.pix_lengths),
289-
pix_pixels=mapper_i.params,
290-
)
279+
diag = inversion_imaging_numba_util.curvature_matrix_via_sparse_operator_from(
280+
psf_precision_operator=self.sparse_operator.psf_precision_operator_sparse,
281+
psf_precision_indexes=self.sparse_operator.indexes,
282+
psf_precision_lengths=self.sparse_operator.lengths,
283+
data_to_pix_unique=np.array(
284+
mapper_i.unique_mappings.data_to_pix_unique
285+
),
286+
data_weights=np.array(mapper_i.unique_mappings.data_weights),
287+
pix_lengths=np.array(mapper_i.unique_mappings.pix_lengths),
288+
pix_pixels=mapper_i.params,
291289
)
292290

293291
curvature_matrix[
@@ -425,19 +423,16 @@ def _curvature_matrix_func_list_and_mapper(self) -> np.ndarray:
425423
/ self.noise_map[:, None] ** 2
426424
)
427425

428-
print(data_linear_func_matrix)
429-
430-
off_diag = inversion_imaging_numba_util.curvature_matrix_off_diags_via_data_linear_func_matrix_from(
431-
data_linear_func_matrix=data_linear_func_matrix,
426+
off_diag = inversion_imaging_numba_util.curvature_matrix_off_diags_via_mapper_and_linear_func_curvature_vector_from(
432427
data_to_pix_unique=mapper.unique_mappings.data_to_pix_unique,
433428
data_weights=mapper.unique_mappings.data_weights,
434429
pix_lengths=mapper.unique_mappings.pix_lengths,
435430
pix_pixels=mapper.params,
431+
curvature_weights=np.array(data_linear_func_matrix),
432+
mask=self.mask.array,
433+
psf_kernel=self.psf.native.array,
436434
)
437435

438-
439-
print(off_diag[0:5, 0:5])
440-
441436
curvature_matrix[
442437
mapper_param_range[0] : mapper_param_range[1],
443438
linear_func_param_range[0] : linear_func_param_range[1],
@@ -470,10 +465,6 @@ def _curvature_matrix_func_list_and_mapper(self) -> np.ndarray:
470465
linear_func_param_range_1[0] : linear_func_param_range_1[1],
471466
] = diag
472467

473-
474-
print(curvature_matrix[0, 2])
475-
ffff
476-
477468
return curvature_matrix
478469

479470
@property

autoarray/inversion/inversion/interferometer/inversion_interferometer_util.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -607,7 +607,7 @@ class InterferometerSparseOperator:
607607
FFT of the curvature preload, shape (2y_shape, 2x_shape), complex.
608608
This is the frequency-domain representation of the W~ operator kernel.
609609
"""
610-
610+
611611
@classmethod
612612
def from_nufft_precision_operator(
613613
cls,

test_autoarray/inversion/inversion/interferometer/test_inversion_interferometer_util.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -90,11 +90,13 @@ def test__curvature_matrix_via_psf_precision_operator_from():
9090
]
9191
)
9292

93-
nufft_precision_operator = aa.util.inversion_interferometer.nufft_precision_operator_from(
94-
noise_map_real=noise_map,
95-
uv_wavelengths=uv_wavelengths,
96-
shape_masked_pixels_2d=(3, 3),
97-
grid_radians_2d=np.array(grid.native),
93+
nufft_precision_operator = (
94+
aa.util.inversion_interferometer.nufft_precision_operator_from(
95+
noise_map_real=noise_map,
96+
uv_wavelengths=uv_wavelengths,
97+
shape_masked_pixels_2d=(3, 3),
98+
grid_radians_2d=np.array(grid.native),
99+
)
98100
)
99101

100102
native_index_for_slim_index = np.array(

test_autoarray/inversion/inversion/test_factory.py

Lines changed: 0 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -390,12 +390,6 @@ def test__inversion_imaging__linear_obj_func_with_sparse_operator(
390390
inversion_sparse_operator.data_vector, 1.0e-4
391391
)
392392

393-
print(inversion_mapping.curvature_matrix[0,2])
394-
print(inversion_mapping.curvature_matrix[0,3])
395-
print(inversion_sparse_operator.curvature_matrix[0,2])
396-
print(inversion_sparse_operator.curvature_matrix[0,3])
397-
ffff
398-
399393
assert inversion_mapping.curvature_matrix == pytest.approx(
400394
inversion_sparse_operator.curvature_matrix, 1.0e-4
401395
)
@@ -559,7 +553,6 @@ def test__inversion_matrices__x2_mappers(
559553
assert inversion.mapped_reconstructed_image[4] == pytest.approx(0.99999998, 1.0e-4)
560554

561555

562-
563556
def test__inversion_imaging__positive_only_solver(masked_imaging_7x7_no_blur):
564557
mask = masked_imaging_7x7_no_blur.mask
565558

0 commit comments

Comments
 (0)