Skip to content

Commit 2220d15

Browse files
authored
Merge pull request #218 from Jammy2211/feature/psf_convolution_refactor
Feature/psf convolution refactor
2 parents 28a2e85 + 45b53ef commit 2220d15

File tree

25 files changed

+741
-1200
lines changed

25 files changed

+741
-1200
lines changed

autoarray/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -77,7 +77,7 @@
7777
from .inversion.mesh.mesh_geometry.delaunay import MeshGeometryDelaunay
7878
from .inversion.mesh.interpolator.rectangular import InterpolatorRectangular
7979
from .inversion.mesh.interpolator.delaunay import InterpolatorDelaunay
80-
from .structures.arrays.kernel_2d import Kernel2D
80+
from .operators.convolver import Convolver
8181
from .structures.vectors.uniform import VectorYX2D
8282
from .structures.vectors.irregular import VectorYX2DIrregular
8383
from .structures.triangles.abstract import AbstractTriangles

autoarray/dataset/grids.py

Lines changed: 6 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
from autoarray.mask.mask_2d import Mask2D
44
from autoarray.structures.arrays.uniform_2d import Array2D
5-
from autoarray.structures.arrays.kernel_2d import Kernel2D
5+
from autoarray.operators.convolver import Convolver
66
from autoarray.structures.grids.uniform_2d import Grid2D
77

88
from autoarray.inversion.mesh.border_relocator import BorderRelocator
@@ -16,7 +16,7 @@ def __init__(
1616
mask: Mask2D,
1717
over_sample_size_lp: Union[int, Array2D],
1818
over_sample_size_pixelization: Union[int, Array2D],
19-
psf: Optional[Kernel2D] = None,
19+
psf: Optional[Convolver] = None,
2020
):
2121
"""
2222
Contains grids of (y,x) Cartesian coordinates at the centre of every pixel in the dataset's image and
@@ -99,12 +99,10 @@ def blurring(self):
9999
if self.psf is None:
100100
self._blurring = None
101101
else:
102-
try:
103-
self._blurring = self.lp.blurring_grid_via_kernel_shape_from(
104-
kernel_shape_native=self.psf.shape_native,
105-
)
106-
except exc.MaskException:
107-
self._blurring = None
102+
self._blurring = Grid2D.from_mask(
103+
mask=self.psf._state.blurring_mask,
104+
over_sample_size=1,
105+
)
108106

109107
return self._blurring
110108

autoarray/dataset/imaging/dataset.py

Lines changed: 23 additions & 101 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,8 @@
99
ImagingSparseOperator,
1010
)
1111
from autoarray.structures.arrays.uniform_2d import Array2D
12-
from autoarray.structures.arrays.kernel_2d import Kernel2D
12+
from autoarray.operators.convolver import ConvolverState
13+
from autoarray.operators.convolver import Convolver
1314
from autoarray.mask.mask_2d import Mask2D
1415
from autoarray import type as ty
1516

@@ -26,11 +27,11 @@ def __init__(
2627
self,
2728
data: Array2D,
2829
noise_map: Optional[Array2D] = None,
29-
psf: Optional[Kernel2D] = None,
30+
psf: Optional[Convolver] = None,
31+
psf_setup_state: bool = False,
3032
noise_covariance_matrix: Optional[np.ndarray] = None,
3133
over_sample_size_lp: Union[int, Array2D] = 4,
3234
over_sample_size_pixelization: Union[int, Array2D] = 4,
33-
disable_fft_pad: bool = True,
3435
use_normalized_psf: Optional[bool] = True,
3536
check_noise_map: bool = True,
3637
sparse_operator: Optional[ImagingSparseOperator] = None,
@@ -78,10 +79,6 @@ def __init__(
7879
over_sample_size_pixelization
7980
How over sampling is performed for the grid which is associated with a pixelization, which is therefore
8081
passed into the calculations performed in the `inversion` module.
81-
disable_fft_pad
82-
The FFT PSF convolution is optimal for a certain 2D FFT padding or trimming, which places the fewest zeros
83-
around the image. If this is set to `True`, this optimal padding is not performed and the image is used
84-
as-is.
8582
use_normalized_psf
8683
If `True`, the PSF kernel values are rescaled such that they sum to 1.0. This can be important for ensuring
8784
the PSF kernel does not change the overall normalization of the image when it is convolved with it.
@@ -93,50 +90,6 @@ def __init__(
9390
enable this linear algebra formalism for pixelized reconstructions.
9491
"""
9592

96-
self.disable_fft_pad = disable_fft_pad
97-
98-
if psf is not None:
99-
100-
full_shape, fft_shape, mask_shape = psf.fft_shape_from(mask=data.mask)
101-
102-
if psf is not None and not disable_fft_pad and data.mask.shape != fft_shape:
103-
104-
# If using real-space convolution instead of FFT, enforce odd-odd shapes
105-
if not psf.use_fft:
106-
fft_shape = tuple(s + 1 if s % 2 == 0 else s for s in fft_shape)
107-
108-
logger.info(
109-
f"Imaging data has been trimmed or padded for FFT convolution.\n"
110-
f" - Original shape : {data.mask.shape}\n"
111-
f" - FFT shape : {fft_shape}\n"
112-
f"Padding ensures accurate PSF convolution in Fourier space. "
113-
f"Set `disable_fft_pad=True` in Imaging object to turn off automatic padding."
114-
)
115-
116-
over_sample_size_lp = (
117-
over_sample_util.over_sample_size_convert_to_array_2d_from(
118-
over_sample_size=over_sample_size_lp, mask=data.mask
119-
)
120-
)
121-
over_sample_size_lp = over_sample_size_lp.resized_from(
122-
new_shape=fft_shape, mask_pad_value=1
123-
)
124-
125-
over_sample_size_pixelization = (
126-
over_sample_util.over_sample_size_convert_to_array_2d_from(
127-
over_sample_size=over_sample_size_pixelization, mask=data.mask
128-
)
129-
)
130-
over_sample_size_pixelization = over_sample_size_pixelization.resized_from(
131-
new_shape=fft_shape, mask_pad_value=1
132-
)
133-
134-
data = data.resized_from(new_shape=fft_shape, mask_pad_value=1)
135-
if noise_map is not None:
136-
noise_map = noise_map.resized_from(
137-
new_shape=fft_shape, mask_pad_value=1
138-
)
139-
14093
super().__init__(
14194
data=data,
14295
noise_map=noise_map,
@@ -145,8 +98,6 @@ def __init__(
14598
over_sample_size_pixelization=over_sample_size_pixelization,
14699
)
147100

148-
self.use_normalized_psf = use_normalized_psf
149-
150101
if self.noise_map.native is not None and check_noise_map:
151102
if ((self.noise_map.native <= 0.0) * np.invert(self.noise_map.mask)).any():
152103
zero_entries = np.argwhere(self.noise_map.native <= 0.0)
@@ -163,36 +114,22 @@ def __init__(
163114

164115
if psf is not None:
165116

166-
if not data.mask.is_all_false:
117+
if use_normalized_psf:
167118

168-
image_mask = data.mask
169-
blurring_mask = data.mask.derive_mask.blurring_from(
170-
kernel_shape_native=psf.shape_native
119+
psf.kernel._array = np.divide(
120+
psf.kernel._array, np.sum(psf.kernel._array)
171121
)
172122

173-
else:
123+
if psf_setup_state:
174124

175-
image_mask = None
176-
blurring_mask = None
125+
state = ConvolverState(kernel=psf.kernel, mask=self.data.mask)
177126

178-
psf = Kernel2D.no_mask(
179-
values=psf.native._array,
180-
pixel_scales=psf.pixel_scales,
181-
normalize=use_normalized_psf,
182-
image_mask=image_mask,
183-
blurring_mask=blurring_mask,
184-
mask_shape=mask_shape,
185-
full_shape=full_shape,
186-
fft_shape=fft_shape,
187-
)
127+
psf = Convolver(
128+
kernel=psf.kernel, state=state, normalize=use_normalized_psf
129+
)
188130

189131
self.psf = psf
190132

191-
if psf is not None:
192-
if not psf.use_fft:
193-
if psf.mask.shape[0] % 2 == 0 or psf.mask.shape[1] % 2 == 0:
194-
raise exc.KernelException("Kernel2D Kernel2D must be odd")
195-
196133
self.grids = GridsDataset(
197134
mask=self.data.mask,
198135
over_sample_size_lp=self.over_sample_size_lp,
@@ -272,14 +209,17 @@ def from_fits(
272209
)
273210

274211
if psf_path is not None:
275-
psf = Kernel2D.from_fits(
212+
kernel = Array2D.from_fits(
276213
file_path=psf_path,
277214
hdu=psf_hdu,
278215
pixel_scales=pixel_scales,
279-
normalize=False,
216+
)
217+
psf = Convolver(
218+
kernel=kernel,
280219
)
281220

282221
else:
222+
kernel = None
283223
psf = None
284224

285225
return Imaging(
@@ -292,7 +232,7 @@ def from_fits(
292232
over_sample_size_pixelization=over_sample_size_pixelization,
293233
)
294234

295-
def apply_mask(self, mask: Mask2D, disable_fft_pad: bool = False) -> "Imaging":
235+
def apply_mask(self, mask: Mask2D) -> "Imaging":
296236
"""
297237
Apply a mask to the imaging dataset, whereby the mask is applied to the image data, noise-map and other
298238
quantities one-by-one.
@@ -340,10 +280,10 @@ def apply_mask(self, mask: Mask2D, disable_fft_pad: bool = False) -> "Imaging":
340280
data=data,
341281
noise_map=noise_map,
342282
psf=self.psf,
283+
psf_setup_state=True,
343284
noise_covariance_matrix=noise_covariance_matrix,
344285
over_sample_size_lp=over_sample_size_lp,
345286
over_sample_size_pixelization=over_sample_size_pixelization,
346-
disable_fft_pad=disable_fft_pad,
347287
)
348288

349289
logger.info(
@@ -356,7 +296,6 @@ def apply_noise_scaling(
356296
self,
357297
mask: Mask2D,
358298
noise_value: float = 1e8,
359-
disable_fft_pad: bool = False,
360299
signal_to_noise_value: Optional[float] = None,
361300
should_zero_data: bool = True,
362301
) -> "Imaging":
@@ -423,7 +362,6 @@ def apply_noise_scaling(
423362
noise_covariance_matrix=self.noise_covariance_matrix,
424363
over_sample_size_lp=self.over_sample_size_lp,
425364
over_sample_size_pixelization=self.over_sample_size_pixelization,
426-
disable_fft_pad=disable_fft_pad,
427365
check_noise_map=False,
428366
)
429367

@@ -437,7 +375,6 @@ def apply_over_sampling(
437375
self,
438376
over_sample_size_lp: Union[int, Array2D] = None,
439377
over_sample_size_pixelization: Union[int, Array2D] = None,
440-
disable_fft_pad: bool = False,
441378
) -> "AbstractDataset":
442379
"""
443380
Apply new over sampling objects to the grid and grid pixelization of the dataset.
@@ -467,7 +404,6 @@ def apply_over_sampling(
467404
over_sample_size_lp=over_sample_size_lp or self.over_sample_size_lp,
468405
over_sample_size_pixelization=over_sample_size_pixelization
469406
or self.over_sample_size_pixelization,
470-
disable_fft_pad=disable_fft_pad,
471407
check_noise_map=False,
472408
)
473409

@@ -476,7 +412,6 @@ def apply_over_sampling(
476412
def apply_sparse_operator(
477413
self,
478414
batch_size: int = 128,
479-
disable_fft_pad: bool = False,
480415
):
481416
"""
482417
The sparse linear algebra formalism precomputes the convolution of every pair of masked
@@ -493,11 +428,6 @@ def apply_sparse_operator(
493428
batch_size
494429
The size of batches used to compute the w-tilde curvature matrix via FFT-based convolution,
495430
which can be reduced to produce lower memory usage at the cost of speed
496-
disable_fft_pad
497-
The FFT PSF convolution is optimal for a certain 2D FFT padding or trimming,
498-
which places the fewest zeros around the image. If this is set to `True`, this optimal padding is not
499-
performed and the image is used as-is. This is normally used to avoid repadding data that has already been
500-
padded.
501431
use_jax
502432
Whether to use JAX to compute W-Tilde. This requires JAX to be installed.
503433
"""
@@ -510,7 +440,7 @@ def apply_sparse_operator(
510440
inversion_imaging_util.ImagingSparseOperator.from_noise_map_and_psf(
511441
data=self.data,
512442
noise_map=self.noise_map,
513-
psf=self.psf.native,
443+
psf=self.psf.kernel.native,
514444
batch_size=batch_size,
515445
)
516446
)
@@ -522,14 +452,12 @@ def apply_sparse_operator(
522452
noise_covariance_matrix=self.noise_covariance_matrix,
523453
over_sample_size_lp=self.over_sample_size_lp,
524454
over_sample_size_pixelization=self.over_sample_size_pixelization,
525-
disable_fft_pad=disable_fft_pad,
526455
check_noise_map=False,
527456
sparse_operator=sparse_operator,
528457
)
529458

530459
def apply_sparse_operator_cpu(
531460
self,
532-
disable_fft_pad: bool = False,
533461
):
534462
"""
535463
The sparse linear algebra formalism precomputes the convolution of every pair of masked
@@ -545,12 +473,7 @@ def apply_sparse_operator_cpu(
545473
-------
546474
batch_size
547475
The size of batches used to compute the w-tilde curvature matrix via FFT-based convolution,
548-
which can be reduced to produce lower memory usage at the cost of speed
549-
disable_fft_pad
550-
The FFT PSF convolution is optimal for a certain 2D FFT padding or trimming,
551-
which places the fewest zeros around the image. If this is set to `True`, this optimal padding is not
552-
performed and the image is used as-is. This is normally used to avoid repadding data that has already been
553-
padded.
476+
which can be reduced to produce lower memory usage at the cost of speed.
554477
use_jax
555478
Whether to use JAX to compute W-Tilde. This requires JAX to be installed.
556479
"""
@@ -575,7 +498,7 @@ def apply_sparse_operator_cpu(
575498
lengths,
576499
) = inversion_imaging_numba_util.psf_precision_operator_sparse_from(
577500
noise_map_native=np.array(self.noise_map.native.array).astype("float64"),
578-
kernel_native=np.array(self.psf.native.array).astype("float64"),
501+
kernel_native=np.array(self.psf.kernel.native.array).astype("float64"),
579502
native_index_for_slim_index=np.array(
580503
self.mask.derive_indexes.native_for_slim
581504
).astype("int"),
@@ -597,7 +520,6 @@ def apply_sparse_operator_cpu(
597520
noise_covariance_matrix=self.noise_covariance_matrix,
598521
over_sample_size_lp=self.over_sample_size_lp,
599522
over_sample_size_pixelization=self.over_sample_size_pixelization,
600-
disable_fft_pad=disable_fft_pad,
601523
check_noise_map=False,
602524
sparse_operator=sparse_operator,
603525
)
@@ -633,7 +555,7 @@ def output_to_fits(
633555
self.data.output_to_fits(file_path=data_path, overwrite=overwrite)
634556

635557
if self.psf is not None and psf_path is not None:
636-
self.psf.output_to_fits(file_path=psf_path, overwrite=overwrite)
558+
self.psf.kernel.output_to_fits(file_path=psf_path, overwrite=overwrite)
637559

638560
if self.noise_map is not None and noise_map_path is not None:
639561
self.noise_map.output_to_fits(file_path=noise_map_path, overwrite=overwrite)

autoarray/dataset/imaging/simulator.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44

55
from autoarray.dataset.imaging.dataset import Imaging
66
from autoarray.structures.arrays.uniform_2d import Array2D
7-
from autoarray.structures.arrays.kernel_2d import Kernel2D
7+
from autoarray.operators.convolver import Convolver
88
from autoarray.mask.mask_2d import Mask2D
99

1010
from autoarray import exc
@@ -19,7 +19,7 @@ def __init__(
1919
exposure_time: float,
2020
background_sky_level: float = 0.0,
2121
subtract_background_sky: bool = True,
22-
psf: Kernel2D = None,
22+
psf: Convolver = None,
2323
use_real_space_convolution: bool = True,
2424
normalize_psf: bool = True,
2525
add_poisson_noise_to_data: bool = True,
@@ -95,7 +95,7 @@ def __init__(
9595
psf = psf.normalized
9696
self.psf = psf
9797
else:
98-
self.psf = Kernel2D.no_blur(pixel_scales=1.0)
98+
self.psf = Convolver.no_blur(pixel_scales=1.0)
9999

100100
self.use_real_space_convolution = use_real_space_convolution
101101
self.exposure_time = exposure_time
@@ -192,13 +192,12 @@ def via_image_from(
192192
psf=self.psf,
193193
noise_map=noise_map,
194194
check_noise_map=False,
195-
disable_fft_pad=True,
196195
)
197196

198197
if over_sample_size is not None:
199198

200199
dataset = dataset.apply_over_sampling(
201-
over_sample_size_lp=over_sample_size.native, disable_fft_pad=True
200+
over_sample_size_lp=over_sample_size.native,
202201
)
203202

204203
return dataset

autoarray/dataset/plot/imaging_plotters.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -93,7 +93,7 @@ def figures_2d(
9393

9494
if psf:
9595
self.mat_plot_2d.plot_array(
96-
array=self.dataset.psf,
96+
array=self.dataset.psf.kernel,
9797
visuals_2d=self.visuals_2d,
9898
auto_labels=AutoLabels(
9999
title=title_str or f"Point Spread Function",

0 commit comments

Comments
 (0)