Skip to content

Commit e7fb78c

Browse files
Jammy2211Jammy2211
authored andcommitted
update FFT padding scheme to be optimal
1 parent 6495882 commit e7fb78c

File tree

3 files changed

+203
-20
lines changed

3 files changed

+203
-20
lines changed

autoarray/dataset/imaging/dataset.py

Lines changed: 18 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import logging
22
import numpy as np
33
from pathlib import Path
4+
import scipy
45
from typing import Optional, Union
56

67
from autoconf import cached_property
@@ -91,9 +92,17 @@ def __init__(
9192

9293
self.pad_for_psf = pad_for_psf
9394

94-
if pad_for_psf and psf is not None:
95+
if pad_for_psf:
96+
full_shape, fft_shape, mask_shape = psf.fft_shape_from(mask=data.mask)
97+
98+
print(data.mask.shape, full_shape, fft_shape, mask_shape, psf.shape_native)
9599

96-
pad_shape = (300, 300)
100+
else:
101+
full_shape = psf.full_shape
102+
fft_shape = psf.fft_shape
103+
mask_shape = psf.mask_shape
104+
105+
if pad_for_psf and psf is not None:
97106

98107
over_sample_size_lp = (
99108
over_sample_util.over_sample_size_convert_to_array_2d_from(
@@ -102,7 +111,7 @@ def __init__(
102111
)
103112
over_sample_size_lp = (
104113
over_sample_size_lp.resized_from(
105-
new_shape=pad_shape, mask_pad_value=1
114+
new_shape=fft_shape, mask_pad_value=1
106115
)
107116
)
108117

@@ -113,16 +122,16 @@ def __init__(
113122
)
114123
over_sample_size_pixelization = (
115124
over_sample_size_pixelization.resized_from(
116-
new_shape=pad_shape, mask_pad_value=1
125+
new_shape=fft_shape, mask_pad_value=1
117126
)
118127
)
119128

120129
data = data.resized_from(
121-
new_shape=pad_shape, mask_pad_value=1
130+
new_shape=fft_shape, mask_pad_value=1
122131
)
123132
if noise_map is not None:
124133
noise_map = noise_map.resized_from(
125-
new_shape=pad_shape, mask_pad_value=1
134+
new_shape=fft_shape, mask_pad_value=1
126135
)
127136
logger.info(
128137
f"The image and noise map of the `Imaging` objected have been padded to the dimensions"
@@ -177,6 +186,9 @@ def __init__(
177186
normalize=use_normalized_psf,
178187
image_mask=image_mask,
179188
blurring_mask=blurring_mask,
189+
mask_shape=mask_shape,
190+
full_shape=full_shape,
191+
fft_shape=fft_shape
180192
)
181193

182194
self.psf = psf

autoarray/inversion/inversion/inversion_util.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -93,7 +93,7 @@ def curvature_matrix_via_mapping_matrix_from(
9393
noise_map
9494
Flattened 1D array of the noise-map used by the inversion during the fit.
9595
"""
96-
array = mapping_matrix / noise_map[:, None]
96+
array = mapping_matrix / jnp.expand_dims(noise_map.array, 1)
9797
curvature_matrix = jnp.dot(array.T, array)
9898

9999
if add_to_curvature_diag and len(no_regularization_index_list) > 0:

autoarray/structures/arrays/kernel_2d.py

Lines changed: 184 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
import jax.numpy as jnp
33
import numpy as np
44
from pathlib import Path
5+
import scipy
56
from typing import List, Optional, Tuple, Union
67

78
from autoconf.fitsable import header_obj_from
@@ -26,6 +27,9 @@ def __init__(
2627
store_native: bool = False,
2728
image_mask=None,
2829
blurring_mask=None,
30+
mask_shape=None,
31+
full_shape=None,
32+
fft_shape=None,
2933
*args,
3034
**kwargs,
3135
):
@@ -77,6 +81,19 @@ def __init__(
7781
slim_to_native_blurring[:, 1],
7882
)
7983

84+
self.fft_shape = fft_shape
85+
86+
self.mask_shape = None
87+
self.full_shape = None
88+
self.fft_psf = None
89+
90+
if self.fft_shape is not None:
91+
92+
self.mask_shape = mask_shape
93+
self.full_shape = full_shape
94+
self.fft_psf = jnp.fft.rfft2(self.native.array, s=self.fft_shape)
95+
self.fft_psf_mapping = jnp.expand_dims(self.fft_psf, 2)
96+
8097
@classmethod
8198
def no_mask(
8299
cls,
@@ -88,6 +105,9 @@ def no_mask(
88105
normalize: bool = False,
89106
image_mask=None,
90107
blurring_mask=None,
108+
mask_shape=None,
109+
full_shape=None,
110+
fft_shape=None
91111
):
92112
"""
93113
Create a Kernel2D (see *Kernel2D.__new__*) by inputting the kernel values in 1D or 2D, automatically
@@ -122,6 +142,9 @@ def no_mask(
122142
normalize=normalize,
123143
image_mask=image_mask,
124144
blurring_mask=blurring_mask,
145+
mask_shape=mask_shape,
146+
full_shape=full_shape,
147+
fft_shape=fft_shape
125148
)
126149

127150
@classmethod
@@ -391,6 +414,21 @@ def from_fits(
391414
header=Header(header_sci_obj=header_sci_obj, header_hdu_obj=header_hdu_obj),
392415
)
393416

417+
def fft_shape_from(self, mask):
418+
419+
ys, xs = np.where(~mask)
420+
y_min, y_max = ys.min(), ys.max()
421+
x_min, x_max = xs.min(), xs.max()
422+
423+
(pad_y, pad_x) = self.shape_native
424+
425+
mask_shape = ((y_max + pad_y // 2) - (y_min - pad_y // 2), (x_max + pad_x // 2) - (x_min - pad_x // 2))
426+
427+
full_shape = tuple(s1 + s2 - 1 for s1, s2 in zip(mask_shape, self.shape_native))
428+
fft_shape = tuple(scipy.fft.next_fast_len(s, real=True) for s in full_shape)
429+
430+
return full_shape, fft_shape, mask_shape
431+
394432
def rescaled_with_odd_dimensions_from(
395433
self, rescale_factor: float, normalize: bool = False
396434
) -> "Kernel2D":
@@ -554,7 +592,7 @@ def convolved_array_with_mask_from(self, array: Array2D, mask) -> Array2D:
554592

555593
return Array2D(values=convolved_array_1d, mask=mask)
556594

557-
def convolve_image_via_real_space(self, image, blurring_image, jax_method="direct"):
595+
def convolve_image(self, image, blurring_image, jax_method="direct"):
558596
"""
559597
For a given 1D array and blurring array, convolve the two using this psf.
560598
@@ -587,27 +625,105 @@ def convolve_image_via_real_space(self, image, blurring_image, jax_method="direc
587625
)
588626

589627
# make sure dtype matches what you want
590-
expanded_array_native = jnp.zeros(
591-
image.mask.shape, dtype=jnp.asarray(image.array).dtype
628+
image_both_native = jnp.zeros(
629+
image.mask.shape, dtype=image.dtype
592630
)
593631

594632
# set using a tuple of index arrays
595-
expanded_array_native = expanded_array_native.at[slim_to_native_tuple].set(
633+
image_both_native = image_both_native.at[slim_to_native_tuple].set(
596634
jnp.asarray(image.array)
597635
)
598-
expanded_array_native = expanded_array_native.at[
636+
image_both_native = image_both_native.at[
599637
slim_to_native_blurring_tuple
600638
].set(jnp.asarray(blurring_image.array))
601639

602-
kernel = self.stored_native.array
640+
# FFT the combined image
641+
fft_image_native = jnp.fft.rfft2(image_both_native, s=self.fft_shape, axes=(0, 1))
603642

604-
convolve_native = jax.scipy.signal.convolve(
605-
expanded_array_native, kernel, mode="same", method=jax_method
643+
# Multiply by PSF in Fourier space and invert
644+
blurred_image_full = jnp.fft.irfft2(self.fft_psf * fft_image_native, s=self.fft_shape, axes=(0, 1))
645+
646+
# Crop back to mask_shape
647+
start_indices = tuple((full_size - out_size) // 2 for full_size, out_size in zip(self.full_shape, self.mask_shape))
648+
out_shape_full = self.mask_shape
649+
blurred_image_native = jax.lax.dynamic_slice(blurred_image_full, start_indices, out_shape_full)
650+
651+
return Array2D(values=blurred_image_native[slim_to_native_tuple], mask=image.mask)
652+
653+
def convolve_mapping_matrix(
654+
self,
655+
mapping_matrix,
656+
mask,
657+
blurring_mapping_matrix=None,
658+
jax_method="direct",
659+
):
660+
"""
661+
Convolve a source-pixel mapping matrix with this PSF in Fourier space.
662+
Also supports a blurring mapping matrix, which is added in the same way as blurring_image.
663+
664+
Parameters
665+
----------
666+
mapping_matrix : (N_masked_pixels, N_src)
667+
Mapping matrix of unmasked pixels to source pixels.
668+
mask : Mask
669+
Mask object with slim-to-native mapping.
670+
blurring_mapping_matrix : (N_blurring_pixels, N_src) or None
671+
Mapping matrix for the blurring grid (outside the mask core).
672+
If provided, this is scattered into native space and added to the main mapping matrix.
673+
jax_method : str
674+
Currently unused, placeholder for different convolution backends.
675+
676+
Returns
677+
-------
678+
(N_masked_pixels, N_src)
679+
Blurred mapping matrix in slim form (only unmasked pixels).
680+
"""
681+
682+
slim_to_native_tuple = self.slim_to_native_tuple
683+
if slim_to_native_tuple is None:
684+
slim_to_native_tuple = jnp.nonzero(
685+
jnp.logical_not(mask.array), size=mask.shape[0]
686+
)
687+
688+
n_src = mapping_matrix.shape[1]
689+
690+
# allocate full native + source dimension
691+
mapping_matrix_native = jnp.zeros(
692+
mask.shape + (n_src,), dtype=mapping_matrix.dtype
606693
)
607694

608-
convolved_array_1d = convolve_native[slim_to_native_tuple]
695+
# scatter main mapping matrix
696+
mapping_matrix_native = mapping_matrix_native.at[slim_to_native_tuple].set(
697+
mapping_matrix
698+
)
609699

610-
return Array2D(values=convolved_array_1d, mask=image.mask)
700+
# optionally scatter blurring mapping matrix
701+
if blurring_mapping_matrix is not None:
702+
slim_to_native_blurring_tuple = self.slim_to_native_blurring_tuple
703+
mapping_matrix_native = mapping_matrix_native.at[
704+
slim_to_native_blurring_tuple
705+
].set(blurring_mapping_matrix)
706+
707+
# FFT convolution
708+
fft_mapping_matrix_native = jnp.fft.rfft2(
709+
mapping_matrix_native, s=self.fft_shape, axes=(0, 1)
710+
)
711+
blurred_mapping_matrix_full = jnp.fft.irfft2(
712+
self.fft_psf_mapping * fft_mapping_matrix_native, s=self.fft_shape, axes=(0, 1)
713+
)
714+
715+
# crop back
716+
start_indices = tuple(
717+
(full_size - out_size) // 2
718+
for full_size, out_size in zip(self.full_shape, self.mask_shape)
719+
) + (0,)
720+
out_shape_full = self.mask_shape + (blurred_mapping_matrix_full.shape[2],)
721+
blurred_mapping_matrix_native = jax.lax.dynamic_slice(
722+
blurred_mapping_matrix_full, start_indices, out_shape_full
723+
)
724+
725+
# return slim form
726+
return blurred_mapping_matrix_native[slim_to_native_tuple]
611727

612728
def convolve_image_no_blurring(self, image, mask, jax_method="direct"):
613729
"""
@@ -657,7 +773,7 @@ def convolve_image_no_blurring(self, image, mask, jax_method="direct"):
657773

658774
return Array2D(values=convolved_array_1d, mask=mask)
659775

660-
def convolve_image_no_blurring_for_mapping(self, image, mask, jax_method="direct"):
776+
def convolve_image_no_blurring_for_mapping_via_real_space(self, image, mask, jax_method="direct"):
661777
"""
662778
For a given 1D array and blurring array, convolve the two using this psf.
663779
@@ -700,7 +816,62 @@ def convolve_image_no_blurring_for_mapping(self, image, mask, jax_method="direct
700816

701817
return Array2D(values=convolved_array_1d, mask=mask)
702818

703-
def convolve_mapping_matrix(self, mapping_matrix, mask, jax_method="direct"):
819+
def convolve_image_via_real_space(self, image, blurring_image, jax_method="direct"):
820+
"""
821+
For a given 1D array and blurring array, convolve the two using this psf.
822+
823+
Parameters
824+
----------
825+
image
826+
1D array of the values which are to be blurred with the psf's PSF.
827+
blurring_image
828+
1D array of the blurring values which blur into the array after PSF convolution.
829+
jax_method
830+
If JAX is enabled this keyword will indicate what method is used for the PSF
831+
convolution. Can be either `direct` to calculate it in real space or `fft`
832+
to calculated it via a fast Fourier transform. `fft` is typically faster for
833+
kernels that are more than about 5x5. Default is `fft`.
834+
"""
835+
836+
slim_to_native_tuple = self.slim_to_native_tuple
837+
slim_to_native_blurring_tuple = self.slim_to_native_blurring_tuple
838+
839+
if slim_to_native_tuple is None:
840+
841+
slim_to_native_tuple = jnp.nonzero(
842+
jnp.logical_not(image.mask.array), size=image.shape[0]
843+
)
844+
845+
if slim_to_native_blurring_tuple is None:
846+
847+
slim_to_native_blurring_tuple = jnp.nonzero(
848+
jnp.logical_not(blurring_image.mask.array), size=blurring_image.shape[0]
849+
)
850+
851+
# make sure dtype matches what you want
852+
expanded_array_native = jnp.zeros(
853+
image.mask.shape, dtype=jnp.asarray(image.array).dtype
854+
)
855+
856+
# set using a tuple of index arrays
857+
expanded_array_native = expanded_array_native.at[slim_to_native_tuple].set(
858+
jnp.asarray(image.array)
859+
)
860+
expanded_array_native = expanded_array_native.at[
861+
slim_to_native_blurring_tuple
862+
].set(jnp.asarray(blurring_image.array))
863+
864+
kernel = self.stored_native.array
865+
866+
convolve_native = jax.scipy.signal.convolve(
867+
expanded_array_native, kernel, mode="same", method=jax_method
868+
)
869+
870+
convolved_array_1d = convolve_native[slim_to_native_tuple]
871+
872+
return Array2D(values=convolved_array_1d, mask=image.mask)
873+
874+
def convolve_mapping_matrix_via_real_space(self, mapping_matrix, mask, jax_method="direct"):
704875
"""For a given 1D array and blurring array, convolve the two using this psf.
705876
706877
Parameters
@@ -709,5 +880,5 @@ def convolve_mapping_matrix(self, mapping_matrix, mask, jax_method="direct"):
709880
1D array of the values which are to be blurred with the psf's PSF.
710881
"""
711882
return jax.vmap(
712-
self.convolve_image_no_blurring_for_mapping, in_axes=(1, None, None)
883+
self.convolve_image_no_blurring_for_mapping_via_real_space, in_axes=(1, None, None)
713884
)(mapping_matrix, mask, jax_method).T

0 commit comments

Comments
 (0)