Skip to content

Commit 15e8cf2

Browse files
authored
Merge pull request #189 from Jammy2211/feature/convolve_fft
Feature/convolve fft
2 parents 5d577d4 + 4e5a647 commit 15e8cf2

File tree

16 files changed

+726
-381
lines changed

16 files changed

+726
-381
lines changed

autoarray/config/general.yaml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@ jax:
22
use_jax: true # If True, uses JAX internally, whereas False uses normal Numpy.
33
fits:
44
flip_for_ds9: false # If True, the image is flipped before output to a .fits file, which is useful for viewing in DS9.
5+
psf:
6+
use_fft_default: true # If True, PSFs are convolved using FFTs by default, which is faster and uses less memory in all cases except for very small PSFs, False uses direct convolution.
57
inversion:
68
check_reconstruction: true # If True, the inversion's reconstruction is checked to ensure the solution of a meshs's mapper is not an invalid solution where the values are all the same.
79
use_positive_only_solver: true # If True, inversion's use a positive-only linear algebra solver by default, which is slower but prevents unphysical negative values in the reconstructed solutuion.

autoarray/dataset/imaging/dataset.py

Lines changed: 65 additions & 79 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
from typing import Optional, Union
55

66
from autoconf import cached_property
7+
from autoconf import instance
78

89
from autoarray.dataset.abstract.dataset import AbstractDataset
910
from autoarray.dataset.grids import GridsDataset
@@ -29,7 +30,7 @@ def __init__(
2930
noise_covariance_matrix: Optional[np.ndarray] = None,
3031
over_sample_size_lp: Union[int, Array2D] = 4,
3132
over_sample_size_pixelization: Union[int, Array2D] = 4,
32-
pad_for_psf: bool = False,
33+
disable_fft_pad: bool = True,
3334
use_normalized_psf: Optional[bool] = True,
3435
check_noise_map: bool = True,
3536
):
@@ -76,63 +77,59 @@ def __init__(
7677
over_sample_size_pixelization
7778
How over sampling is performed for the grid which is associated with a pixelization, which is therefore
7879
passed into the calculations performed in the `inversion` module.
79-
pad_for_psf
80-
The PSF convolution may extend beyond the edges of the image mask, which can lead to edge effects in the
81-
convolved image. If `True`, the image and noise-map are padded to ensure the PSF convolution does not
82-
extend beyond the edge of the image.
80+
disable_fft_pad
81+
The FFT PSF convolution is optimal for a certain 2D FFT padding or trimming, which places the fewest zeros
82+
around the image. If this is set to `True`, this optimal padding is not performed and the image is used
83+
as-is.
8384
use_normalized_psf
8485
If `True`, the PSF kernel values are rescaled such that they sum to 1.0. This can be important for ensuring
8586
the PSF kernel does not change the overall normalization of the image when it is convolved with it.
8687
check_noise_map
8788
If True, the noise-map is checked to ensure all values are above zero.
8889
"""
8990

90-
self.unmasked = None
91+
self.disable_fft_pad = disable_fft_pad
9192

92-
self.pad_for_psf = pad_for_psf
93+
if psf is not None:
9394

94-
if pad_for_psf and psf is not None:
95-
try:
96-
data.mask.derive_mask.blurring_from(
97-
kernel_shape_native=psf.shape_native
98-
)
99-
except exc.MaskException:
100-
over_sample_size_lp = (
101-
over_sample_util.over_sample_size_convert_to_array_2d_from(
102-
over_sample_size=over_sample_size_lp, mask=data.mask
103-
)
104-
)
105-
over_sample_size_lp = (
106-
over_sample_size_lp.padded_before_convolution_from(
107-
kernel_shape=psf.shape_native, mask_pad_value=1
108-
)
109-
)
95+
full_shape, fft_shape, mask_shape = psf.fft_shape_from(mask=data.mask)
11096

111-
over_sample_size_pixelization = (
112-
over_sample_util.over_sample_size_convert_to_array_2d_from(
113-
over_sample_size=over_sample_size_pixelization, mask=data.mask
114-
)
115-
)
116-
over_sample_size_pixelization = (
117-
over_sample_size_pixelization.padded_before_convolution_from(
118-
kernel_shape=psf.shape_native, mask_pad_value=1
119-
)
97+
if psf is not None and not disable_fft_pad and data.mask.shape != fft_shape:
98+
99+
# If using real-space convolution instead of FFT, enforce odd-odd shapes
100+
if not psf.use_fft:
101+
fft_shape = tuple(s + 1 if s % 2 == 0 else s for s in fft_shape)
102+
103+
logger.info(
104+
f"Imaging data has been trimmed or padded for FFT convolution.\n"
105+
f" - Original shape : {data.mask.shape}\n"
106+
f" - FFT shape : {fft_shape}\n"
107+
f"Padding ensures accurate PSF convolution in Fourier space. "
108+
f"Set `disable_fft_pad=True` in Imaging object to turn off automatic padding."
109+
)
110+
111+
over_sample_size_lp = (
112+
over_sample_util.over_sample_size_convert_to_array_2d_from(
113+
over_sample_size=over_sample_size_lp, mask=data.mask
120114
)
115+
)
116+
over_sample_size_lp = over_sample_size_lp.resized_from(
117+
new_shape=fft_shape, mask_pad_value=1
118+
)
121119

122-
data = data.padded_before_convolution_from(
123-
kernel_shape=psf.shape_native, mask_pad_value=1
120+
over_sample_size_pixelization = (
121+
over_sample_util.over_sample_size_convert_to_array_2d_from(
122+
over_sample_size=over_sample_size_pixelization, mask=data.mask
124123
)
125-
if noise_map is not None:
126-
noise_map = noise_map.padded_before_convolution_from(
127-
kernel_shape=psf.shape_native, mask_pad_value=1
128-
)
129-
logger.info(
130-
f"The image and noise map of the `Imaging` objected have been padded to the dimensions"
131-
f"{data.shape}. This is because the blurring region around the mask (which defines where"
132-
f"PSF flux may be convolved into the masked region) extended beyond the edge of the image."
133-
f""
134-
f"This can be prevented by using a smaller mask, smaller PSF kernel size or manually padding"
135-
f"the image and noise-map yourself."
124+
)
125+
over_sample_size_pixelization = over_sample_size_pixelization.resized_from(
126+
new_shape=fft_shape, mask_pad_value=1
127+
)
128+
129+
data = data.resized_from(new_shape=fft_shape, mask_pad_value=1)
130+
if noise_map is not None:
131+
noise_map = noise_map.resized_from(
132+
new_shape=fft_shape, mask_pad_value=1
136133
)
137134

138135
super().__init__(
@@ -179,6 +176,9 @@ def __init__(
179176
normalize=use_normalized_psf,
180177
image_mask=image_mask,
181178
blurring_mask=blurring_mask,
179+
mask_shape=mask_shape,
180+
full_shape=full_shape,
181+
fft_shape=fft_shape,
182182
)
183183

184184
self.psf = psf
@@ -337,31 +337,34 @@ def from_fits(
337337
over_sample_size_pixelization=over_sample_size_pixelization,
338338
)
339339

340-
def apply_mask(self, mask: Mask2D) -> "Imaging":
340+
def apply_mask(self, mask: Mask2D, disable_fft_pad: bool = False) -> "Imaging":
341341
"""
342342
Apply a mask to the imaging dataset, whereby the mask is applied to the image data, noise-map and other
343343
quantities one-by-one.
344344
345-
The original unmasked imaging data is stored as the `self.unmasked` attribute. This is used to ensure that if
346-
the `apply_mask` function is called multiple times, every mask is always applied to the original unmasked
347-
imaging dataset.
345+
The `apply_mask` function cannot be called multiple times, if it is a mask may remove data, therefore
346+
an exception is raised. If you wish to apply a new mask, reload the dataset from .fits files.
348347
349348
Parameters
350349
----------
351350
mask
352351
The 2D mask that is applied to the image.
353352
"""
354-
if self.data.mask.is_all_false:
355-
unmasked_dataset = self
356-
else:
357-
unmasked_dataset = self.unmasked
353+
invalid = np.logical_and(self.data.mask, np.logical_not(mask))
354+
355+
if np.any(invalid):
356+
raise exc.DatasetException(
357+
"The new mask overlaps with pixels that are already unmasked in the dataset. "
358+
"You cannot apply a new mask on top of an existing one. "
359+
"If you wish to apply a different mask, please reload the dataset from .fits files."
360+
)
358361

359-
data = Array2D(values=unmasked_dataset.data.native, mask=mask)
362+
data = Array2D(values=self.data.native, mask=mask)
360363

361-
noise_map = Array2D(values=unmasked_dataset.noise_map.native, mask=mask)
364+
noise_map = Array2D(values=self.noise_map.native, mask=mask)
362365

363-
if unmasked_dataset.noise_covariance_matrix is not None:
364-
noise_covariance_matrix = unmasked_dataset.noise_covariance_matrix
366+
if self.noise_covariance_matrix is not None:
367+
noise_covariance_matrix = self.noise_covariance_matrix
365368

366369
noise_covariance_matrix = np.delete(
367370
noise_covariance_matrix, mask.derive_indexes.masked_slim, 0
@@ -385,11 +388,9 @@ def apply_mask(self, mask: Mask2D) -> "Imaging":
385388
noise_covariance_matrix=noise_covariance_matrix,
386389
over_sample_size_lp=over_sample_size_lp,
387390
over_sample_size_pixelization=over_sample_size_pixelization,
388-
pad_for_psf=True,
391+
disable_fft_pad=disable_fft_pad,
389392
)
390393

391-
dataset.unmasked = unmasked_dataset
392-
393394
logger.info(
394395
f"IMAGING - Data masked, contains a total of {mask.pixels_in_mask} image-pixels"
395396
)
@@ -400,6 +401,7 @@ def apply_noise_scaling(
400401
self,
401402
mask: Mask2D,
402403
noise_value: float = 1e8,
404+
disable_fft_pad: bool = False,
403405
signal_to_noise_value: Optional[float] = None,
404406
should_zero_data: bool = True,
405407
) -> "Imaging":
@@ -455,18 +457,6 @@ def apply_noise_scaling(
455457
else:
456458
data = self.data.native.array
457459

458-
data_unmasked = Array2D.no_mask(
459-
values=data,
460-
shape_native=self.data.shape_native,
461-
pixel_scales=self.data.pixel_scales,
462-
)
463-
464-
noise_map_unmasked = Array2D.no_mask(
465-
values=noise_map,
466-
shape_native=self.noise_map.shape_native,
467-
pixel_scales=self.noise_map.pixel_scales,
468-
)
469-
470460
data = Array2D(values=data, mask=self.data.mask)
471461

472462
noise_map = Array2D(values=noise_map, mask=self.data.mask)
@@ -478,15 +468,10 @@ def apply_noise_scaling(
478468
noise_covariance_matrix=self.noise_covariance_matrix,
479469
over_sample_size_lp=self.over_sample_size_lp,
480470
over_sample_size_pixelization=self.over_sample_size_pixelization,
481-
pad_for_psf=False,
471+
disable_fft_pad=disable_fft_pad,
482472
check_noise_map=False,
483473
)
484474

485-
if self.unmasked is not None:
486-
dataset.unmasked = self.unmasked
487-
dataset.unmasked.data = data_unmasked
488-
dataset.unmasked.noise_map = noise_map_unmasked
489-
490475
logger.info(
491476
f"IMAGING - Data noise scaling applied, a total of {mask.pixels_in_mask} pixels were scaled to large noise values."
492477
)
@@ -497,6 +482,7 @@ def apply_over_sampling(
497482
self,
498483
over_sample_size_lp: Union[int, Array2D] = None,
499484
over_sample_size_pixelization: Union[int, Array2D] = None,
485+
disable_fft_pad: bool = False,
500486
) -> "AbstractDataset":
501487
"""
502488
Apply new over sampling objects to the grid and grid pixelization of the dataset.
@@ -526,7 +512,7 @@ def apply_over_sampling(
526512
over_sample_size_lp=over_sample_size_lp or self.over_sample_size_lp,
527513
over_sample_size_pixelization=over_sample_size_pixelization
528514
or self.over_sample_size_pixelization,
529-
pad_for_psf=False,
515+
disable_fft_pad=disable_fft_pad,
530516
check_noise_map=False,
531517
)
532518

autoarray/dataset/imaging/simulator.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -126,7 +126,7 @@ def via_image_from(
126126
pixel_scales=image.pixel_scales,
127127
)
128128

129-
image = self.psf.convolved_array_from(array=image)
129+
image = self.psf.convolved_image_from(image=image, blurring_image=None)
130130

131131
image = image + background_sky_map
132132

@@ -169,12 +169,16 @@ def via_image_from(
169169
image = Array2D(values=image, mask=mask)
170170

171171
dataset = Imaging(
172-
data=image, psf=self.psf, noise_map=noise_map, check_noise_map=False
172+
data=image,
173+
psf=self.psf,
174+
noise_map=noise_map,
175+
check_noise_map=False,
176+
disable_fft_pad=True,
173177
)
174178

175179
if over_sample_size is not None:
176180
dataset = dataset.apply_over_sampling(
177-
over_sample_size_lp=over_sample_size.native
181+
over_sample_size_lp=over_sample_size.native, disable_fft_pad=True
178182
)
179183

180184
return dataset

autoarray/geometry/geometry_1d.py

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,6 @@
11
from __future__ import annotations
22
import logging
3-
import numpy as np
4-
from typing import TYPE_CHECKING, List, Tuple, Union
5-
6-
if TYPE_CHECKING:
7-
from autoarray.structures.grids.uniform_1d import Grid1D
8-
from autoarray.mask.mask_2d import Mask2D
3+
from typing import Tuple
94

105
from autoarray import type as ty
116

autoarray/inversion/inversion/imaging/abstract.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -92,7 +92,7 @@ def operated_mapping_matrix_list(self) -> List[np.ndarray]:
9292

9393
return [
9494
(
95-
self.psf.convolve_mapping_matrix(
95+
self.psf.convolved_mapping_matrix_from(
9696
mapping_matrix=linear_obj.mapping_matrix, mask=self.mask
9797
)
9898
if linear_obj.operated_mapping_matrix_override is None
@@ -134,7 +134,7 @@ def linear_func_operated_mapping_matrix_dict(self) -> Dict:
134134
if linear_func.operated_mapping_matrix_override is not None:
135135
operated_mapping_matrix = linear_func.operated_mapping_matrix_override
136136
else:
137-
operated_mapping_matrix = self.psf.convolve_mapping_matrix(
137+
operated_mapping_matrix = self.psf.convolved_mapping_matrix_from(
138138
mapping_matrix=linear_func.mapping_matrix,
139139
mask=self.mask,
140140
)
@@ -215,7 +215,7 @@ def mapper_operated_mapping_matrix_dict(self) -> Dict:
215215
mapper_operated_mapping_matrix_dict = {}
216216

217217
for mapper in self.cls_list_from(cls=AbstractMapper):
218-
operated_mapping_matrix = self.psf.convolve_mapping_matrix(
218+
operated_mapping_matrix = self.psf.convolved_mapping_matrix_from(
219219
mapping_matrix=mapper.mapping_matrix,
220220
mask=self.mask,
221221
)

autoarray/inversion/inversion/imaging/mapping.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -73,7 +73,7 @@ def _data_vector_mapper(self) -> np.ndarray:
7373
mapper = mapper_list[i]
7474
param_range = mapper_param_range_list[i]
7575

76-
operated_mapping_matrix = self.psf.convolve_mapping_matrix(
76+
operated_mapping_matrix = self.psf.convolved_mapping_matrix_from(
7777
mapping_matrix=mapper.mapping_matrix, mask=self.mask
7878
)
7979

@@ -132,7 +132,7 @@ def _curvature_matrix_mapper_diag(self) -> Optional[np.ndarray]:
132132
mapper_i = mapper_list[i]
133133
mapper_param_range_i = mapper_param_range_list[i]
134134

135-
operated_mapping_matrix = self.psf.convolve_mapping_matrix(
135+
operated_mapping_matrix = self.psf.convolved_mapping_matrix_from(
136136
mapping_matrix=mapper_i.mapping_matrix, mask=self.mask
137137
)
138138

autoarray/inversion/inversion/imaging/w_tilde.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -518,8 +518,13 @@ def mapped_reconstructed_data_dict(self) -> Dict[LinearObj, Array2D]:
518518
reconstruction=np.array(reconstruction),
519519
)
520520

521-
mapped_reconstructed_image = self.psf.convolve_image_no_blurring(
522-
image=mapped_reconstructed_image, mask=self.mask
521+
mapped_reconstructed_image = Array2D(
522+
values=mapped_reconstructed_image, mask=self.mask
523+
)
524+
525+
mapped_reconstructed_image = self.psf.convolved_image_from(
526+
image=mapped_reconstructed_image,
527+
blurring_image=None,
523528
).array
524529

525530
mapped_reconstructed_image = Array2D(

autoarray/mask/derive/mask_2d.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
from __future__ import annotations
22
import logging
3-
import copy
43
import numpy as np
54
from typing import TYPE_CHECKING, Tuple
65

@@ -10,7 +9,6 @@
109
from autoarray import exc
1110
from autoarray.mask.derive.indexes_2d import DeriveIndexes2D
1211

13-
from autoarray.structures.arrays import array_2d_util
1412
from autoarray.mask import mask_2d_util
1513

1614
logging.basicConfig()

autoarray/mask/mask_2d.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -653,7 +653,9 @@ def unmasked_blurred_array_from(self, padded_array, psf, image_shape) -> Array2D
653653
The 1D unmasked image which is blurred.
654654
"""
655655

656-
blurred_image = psf.convolved_array_from(array=padded_array)
656+
blurred_image = psf.convolved_image_from(
657+
image=padded_array, blurring_image=None
658+
)
657659

658660
return self.trimmed_array_from(
659661
padded_array=blurred_image, image_shape=image_shape

autoarray/operators/mock/mock_psf.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,5 +2,5 @@ class MockPSF:
22
def __init__(self, operated_mapping_matrix=None):
33
self.operated_mapping_matrix = operated_mapping_matrix
44

5-
def convolve_mapping_matrix(self, mapping_matrix, mask):
5+
def convolved_mapping_matrix_from(self, mapping_matrix, mask):
66
return self.operated_mapping_matrix

0 commit comments

Comments
 (0)