99 ImagingSparseOperator ,
1010)
1111from autoarray .structures .arrays .uniform_2d import Array2D
12- from autoarray .structures .arrays .kernel_2d import Kernel2D
12+ from autoarray .operators .convolver import ConvolverState
13+ from autoarray .operators .convolver import Convolver
1314from autoarray .mask .mask_2d import Mask2D
1415from autoarray import type as ty
1516
@@ -26,11 +27,11 @@ def __init__(
2627 self ,
2728 data : Array2D ,
2829 noise_map : Optional [Array2D ] = None ,
29- psf : Optional [Kernel2D ] = None ,
30+ psf : Optional [Convolver ] = None ,
31+ psf_setup_state : bool = False ,
3032 noise_covariance_matrix : Optional [np .ndarray ] = None ,
3133 over_sample_size_lp : Union [int , Array2D ] = 4 ,
3234 over_sample_size_pixelization : Union [int , Array2D ] = 4 ,
33- disable_fft_pad : bool = True ,
3435 use_normalized_psf : Optional [bool ] = True ,
3536 check_noise_map : bool = True ,
3637 sparse_operator : Optional [ImagingSparseOperator ] = None ,
@@ -78,10 +79,6 @@ def __init__(
7879 over_sample_size_pixelization
7980 How over sampling is performed for the grid which is associated with a pixelization, which is therefore
8081 passed into the calculations performed in the `inversion` module.
81- disable_fft_pad
82- The FFT PSF convolution is optimal for a certain 2D FFT padding or trimming, which places the fewest zeros
83- around the image. If this is set to `True`, this optimal padding is not performed and the image is used
84- as-is.
8582 use_normalized_psf
8683 If `True`, the PSF kernel values are rescaled such that they sum to 1.0. This can be important for ensuring
8784 the PSF kernel does not change the overall normalization of the image when it is convolved with it.
@@ -93,50 +90,6 @@ def __init__(
9390 enable this linear algebra formalism for pixelized reconstructions.
9491 """
9592
96- self .disable_fft_pad = disable_fft_pad
97-
98- if psf is not None :
99-
100- full_shape , fft_shape , mask_shape = psf .fft_shape_from (mask = data .mask )
101-
102- if psf is not None and not disable_fft_pad and data .mask .shape != fft_shape :
103-
104- # If using real-space convolution instead of FFT, enforce odd-odd shapes
105- if not psf .use_fft :
106- fft_shape = tuple (s + 1 if s % 2 == 0 else s for s in fft_shape )
107-
108- logger .info (
109- f"Imaging data has been trimmed or padded for FFT convolution.\n "
110- f" - Original shape : { data .mask .shape } \n "
111- f" - FFT shape : { fft_shape } \n "
112- f"Padding ensures accurate PSF convolution in Fourier space. "
113- f"Set `disable_fft_pad=True` in Imaging object to turn off automatic padding."
114- )
115-
116- over_sample_size_lp = (
117- over_sample_util .over_sample_size_convert_to_array_2d_from (
118- over_sample_size = over_sample_size_lp , mask = data .mask
119- )
120- )
121- over_sample_size_lp = over_sample_size_lp .resized_from (
122- new_shape = fft_shape , mask_pad_value = 1
123- )
124-
125- over_sample_size_pixelization = (
126- over_sample_util .over_sample_size_convert_to_array_2d_from (
127- over_sample_size = over_sample_size_pixelization , mask = data .mask
128- )
129- )
130- over_sample_size_pixelization = over_sample_size_pixelization .resized_from (
131- new_shape = fft_shape , mask_pad_value = 1
132- )
133-
134- data = data .resized_from (new_shape = fft_shape , mask_pad_value = 1 )
135- if noise_map is not None :
136- noise_map = noise_map .resized_from (
137- new_shape = fft_shape , mask_pad_value = 1
138- )
139-
14093 super ().__init__ (
14194 data = data ,
14295 noise_map = noise_map ,
@@ -145,8 +98,6 @@ def __init__(
14598 over_sample_size_pixelization = over_sample_size_pixelization ,
14699 )
147100
148- self .use_normalized_psf = use_normalized_psf
149-
150101 if self .noise_map .native is not None and check_noise_map :
151102 if ((self .noise_map .native <= 0.0 ) * np .invert (self .noise_map .mask )).any ():
152103 zero_entries = np .argwhere (self .noise_map .native <= 0.0 )
@@ -163,36 +114,22 @@ def __init__(
163114
164115 if psf is not None :
165116
166- if not data . mask . is_all_false :
117+ if use_normalized_psf :
167118
168- image_mask = data .mask
169- blurring_mask = data .mask .derive_mask .blurring_from (
170- kernel_shape_native = psf .shape_native
119+ psf .kernel ._array = np .divide (
120+ psf .kernel ._array , np .sum (psf .kernel ._array )
171121 )
172122
173- else :
123+ if psf_setup_state :
174124
175- image_mask = None
176- blurring_mask = None
125+ state = ConvolverState (kernel = psf .kernel , mask = self .data .mask )
177126
178- psf = Kernel2D .no_mask (
179- values = psf .native ._array ,
180- pixel_scales = psf .pixel_scales ,
181- normalize = use_normalized_psf ,
182- image_mask = image_mask ,
183- blurring_mask = blurring_mask ,
184- mask_shape = mask_shape ,
185- full_shape = full_shape ,
186- fft_shape = fft_shape ,
187- )
127+ psf = Convolver (
128+ kernel = psf .kernel , state = state , normalize = use_normalized_psf
129+ )
188130
189131 self .psf = psf
190132
191- if psf is not None :
192- if not psf .use_fft :
193- if psf .mask .shape [0 ] % 2 == 0 or psf .mask .shape [1 ] % 2 == 0 :
194- raise exc .KernelException ("Kernel2D Kernel2D must be odd" )
195-
196133 self .grids = GridsDataset (
197134 mask = self .data .mask ,
198135 over_sample_size_lp = self .over_sample_size_lp ,
@@ -272,14 +209,17 @@ def from_fits(
272209 )
273210
274211 if psf_path is not None :
275- psf = Kernel2D .from_fits (
212+ kernel = Array2D .from_fits (
276213 file_path = psf_path ,
277214 hdu = psf_hdu ,
278215 pixel_scales = pixel_scales ,
279- normalize = False ,
216+ )
217+ psf = Convolver (
218+ kernel = kernel ,
280219 )
281220
282221 else :
222+ kernel = None
283223 psf = None
284224
285225 return Imaging (
@@ -292,7 +232,7 @@ def from_fits(
292232 over_sample_size_pixelization = over_sample_size_pixelization ,
293233 )
294234
295- def apply_mask (self , mask : Mask2D , disable_fft_pad : bool = False ) -> "Imaging" :
235+ def apply_mask (self , mask : Mask2D ) -> "Imaging" :
296236 """
297237 Apply a mask to the imaging dataset, whereby the mask is applied to the image data, noise-map and other
298238 quantities one-by-one.
@@ -340,10 +280,10 @@ def apply_mask(self, mask: Mask2D, disable_fft_pad: bool = False) -> "Imaging":
340280 data = data ,
341281 noise_map = noise_map ,
342282 psf = self .psf ,
283+ psf_setup_state = True ,
343284 noise_covariance_matrix = noise_covariance_matrix ,
344285 over_sample_size_lp = over_sample_size_lp ,
345286 over_sample_size_pixelization = over_sample_size_pixelization ,
346- disable_fft_pad = disable_fft_pad ,
347287 )
348288
349289 logger .info (
@@ -356,7 +296,6 @@ def apply_noise_scaling(
356296 self ,
357297 mask : Mask2D ,
358298 noise_value : float = 1e8 ,
359- disable_fft_pad : bool = False ,
360299 signal_to_noise_value : Optional [float ] = None ,
361300 should_zero_data : bool = True ,
362301 ) -> "Imaging" :
@@ -423,7 +362,6 @@ def apply_noise_scaling(
423362 noise_covariance_matrix = self .noise_covariance_matrix ,
424363 over_sample_size_lp = self .over_sample_size_lp ,
425364 over_sample_size_pixelization = self .over_sample_size_pixelization ,
426- disable_fft_pad = disable_fft_pad ,
427365 check_noise_map = False ,
428366 )
429367
@@ -437,7 +375,6 @@ def apply_over_sampling(
437375 self ,
438376 over_sample_size_lp : Union [int , Array2D ] = None ,
439377 over_sample_size_pixelization : Union [int , Array2D ] = None ,
440- disable_fft_pad : bool = False ,
441378 ) -> "AbstractDataset" :
442379 """
443380 Apply new over sampling objects to the grid and grid pixelization of the dataset.
@@ -467,7 +404,6 @@ def apply_over_sampling(
467404 over_sample_size_lp = over_sample_size_lp or self .over_sample_size_lp ,
468405 over_sample_size_pixelization = over_sample_size_pixelization
469406 or self .over_sample_size_pixelization ,
470- disable_fft_pad = disable_fft_pad ,
471407 check_noise_map = False ,
472408 )
473409
@@ -476,7 +412,6 @@ def apply_over_sampling(
476412 def apply_sparse_operator (
477413 self ,
478414 batch_size : int = 128 ,
479- disable_fft_pad : bool = False ,
480415 ):
481416 """
482417 The sparse linear algebra formalism precomputes the convolution of every pair of masked
@@ -493,11 +428,6 @@ def apply_sparse_operator(
493428 batch_size
494429 The size of batches used to compute the w-tilde curvature matrix via FFT-based convolution,
495430 which can be reduced to produce lower memory usage at the cost of speed
496- disable_fft_pad
497- The FFT PSF convolution is optimal for a certain 2D FFT padding or trimming,
498- which places the fewest zeros around the image. If this is set to `True`, this optimal padding is not
499- performed and the image is used as-is. This is normally used to avoid repadding data that has already been
500- padded.
501431 use_jax
502432 Whether to use JAX to compute W-Tilde. This requires JAX to be installed.
503433 """
@@ -510,7 +440,7 @@ def apply_sparse_operator(
510440 inversion_imaging_util .ImagingSparseOperator .from_noise_map_and_psf (
511441 data = self .data ,
512442 noise_map = self .noise_map ,
513- psf = self .psf .native ,
443+ psf = self .psf .kernel . native ,
514444 batch_size = batch_size ,
515445 )
516446 )
@@ -522,14 +452,12 @@ def apply_sparse_operator(
522452 noise_covariance_matrix = self .noise_covariance_matrix ,
523453 over_sample_size_lp = self .over_sample_size_lp ,
524454 over_sample_size_pixelization = self .over_sample_size_pixelization ,
525- disable_fft_pad = disable_fft_pad ,
526455 check_noise_map = False ,
527456 sparse_operator = sparse_operator ,
528457 )
529458
530459 def apply_sparse_operator_cpu (
531460 self ,
532- disable_fft_pad : bool = False ,
533461 ):
534462 """
535463 The sparse linear algebra formalism precomputes the convolution of every pair of masked
@@ -545,12 +473,7 @@ def apply_sparse_operator_cpu(
545473 -------
546474 batch_size
547475 The size of batches used to compute the w-tilde curvature matrix via FFT-based convolution,
548- which can be reduced to produce lower memory usage at the cost of speed
549- disable_fft_pad
550- The FFT PSF convolution is optimal for a certain 2D FFT padding or trimming,
551- which places the fewest zeros around the image. If this is set to `True`, this optimal padding is not
552- performed and the image is used as-is. This is normally used to avoid repadding data that has already been
553- padded.
476+ which can be reduced to produce lower memory usage at the cost of speed.
554477 use_jax
555478 Whether to use JAX to compute W-Tilde. This requires JAX to be installed.
556479 """
@@ -575,7 +498,7 @@ def apply_sparse_operator_cpu(
575498 lengths ,
576499 ) = inversion_imaging_numba_util .psf_precision_operator_sparse_from (
577500 noise_map_native = np .array (self .noise_map .native .array ).astype ("float64" ),
578- kernel_native = np .array (self .psf .native .array ).astype ("float64" ),
501+ kernel_native = np .array (self .psf .kernel . native .array ).astype ("float64" ),
579502 native_index_for_slim_index = np .array (
580503 self .mask .derive_indexes .native_for_slim
581504 ).astype ("int" ),
@@ -597,7 +520,6 @@ def apply_sparse_operator_cpu(
597520 noise_covariance_matrix = self .noise_covariance_matrix ,
598521 over_sample_size_lp = self .over_sample_size_lp ,
599522 over_sample_size_pixelization = self .over_sample_size_pixelization ,
600- disable_fft_pad = disable_fft_pad ,
601523 check_noise_map = False ,
602524 sparse_operator = sparse_operator ,
603525 )
@@ -633,7 +555,7 @@ def output_to_fits(
633555 self .data .output_to_fits (file_path = data_path , overwrite = overwrite )
634556
635557 if self .psf is not None and psf_path is not None :
636- self .psf .output_to_fits (file_path = psf_path , overwrite = overwrite )
558+ self .psf .kernel . output_to_fits (file_path = psf_path , overwrite = overwrite )
637559
638560 if self .noise_map is not None and noise_map_path is not None :
639561 self .noise_map .output_to_fits (file_path = noise_map_path , overwrite = overwrite )
0 commit comments