22import jax .numpy as jnp
33import numpy as np
44from pathlib import Path
5+ import scipy
56from typing import List , Optional , Tuple , Union
67
78from autoconf .fitsable import header_obj_from
@@ -26,6 +27,9 @@ def __init__(
2627 store_native : bool = False ,
2728 image_mask = None ,
2829 blurring_mask = None ,
30+ mask_shape = None ,
31+ full_shape = None ,
32+ fft_shape = None ,
2933 * args ,
3034 ** kwargs ,
3135 ):
@@ -77,6 +81,19 @@ def __init__(
7781 slim_to_native_blurring [:, 1 ],
7882 )
7983
84+ self .fft_shape = fft_shape
85+
86+ self .mask_shape = None
87+ self .full_shape = None
88+ self .fft_psf = None
89+
90+ if self .fft_shape is not None :
91+
92+ self .mask_shape = mask_shape
93+ self .full_shape = full_shape
94+ self .fft_psf = jnp .fft .rfft2 (self .native .array , s = self .fft_shape )
95+ self .fft_psf_mapping = jnp .expand_dims (self .fft_psf , 2 )
96+
8097 @classmethod
8198 def no_mask (
8299 cls ,
@@ -88,6 +105,9 @@ def no_mask(
88105 normalize : bool = False ,
89106 image_mask = None ,
90107 blurring_mask = None ,
108+ mask_shape = None ,
109+ full_shape = None ,
110+ fft_shape = None
91111 ):
92112 """
93113 Create a Kernel2D (see *Kernel2D.__new__*) by inputting the kernel values in 1D or 2D, automatically
@@ -122,6 +142,9 @@ def no_mask(
122142 normalize = normalize ,
123143 image_mask = image_mask ,
124144 blurring_mask = blurring_mask ,
145+ mask_shape = mask_shape ,
146+ full_shape = full_shape ,
147+ fft_shape = fft_shape
125148 )
126149
127150 @classmethod
@@ -391,6 +414,21 @@ def from_fits(
391414 header = Header (header_sci_obj = header_sci_obj , header_hdu_obj = header_hdu_obj ),
392415 )
393416
417+ def fft_shape_from (self , mask ):
418+
419+ ys , xs = np .where (~ mask )
420+ y_min , y_max = ys .min (), ys .max ()
421+ x_min , x_max = xs .min (), xs .max ()
422+
423+ (pad_y , pad_x ) = self .shape_native
424+
425+ mask_shape = ((y_max + pad_y // 2 ) - (y_min - pad_y // 2 ), (x_max + pad_x // 2 ) - (x_min - pad_x // 2 ))
426+
427+ full_shape = tuple (s1 + s2 - 1 for s1 , s2 in zip (mask_shape , self .shape_native ))
428+ fft_shape = tuple (scipy .fft .next_fast_len (s , real = True ) for s in full_shape )
429+
430+ return full_shape , fft_shape , mask_shape
431+
394432 def rescaled_with_odd_dimensions_from (
395433 self , rescale_factor : float , normalize : bool = False
396434 ) -> "Kernel2D" :
@@ -554,7 +592,7 @@ def convolved_array_with_mask_from(self, array: Array2D, mask) -> Array2D:
554592
555593 return Array2D (values = convolved_array_1d , mask = mask )
556594
557- def convolve_image_via_real_space (self , image , blurring_image , jax_method = "direct" ):
595+ def convolve_image (self , image , blurring_image , jax_method = "direct" ):
558596 """
559597 For a given 1D array and blurring array, convolve the two using this psf.
560598
@@ -587,27 +625,105 @@ def convolve_image_via_real_space(self, image, blurring_image, jax_method="direc
587625 )
588626
589627 # make sure dtype matches what you want
590- expanded_array_native = jnp .zeros (
591- image .mask .shape , dtype = jnp . asarray ( image . array ) .dtype
628+ image_both_native = jnp .zeros (
629+ image .mask .shape , dtype = image .dtype
592630 )
593631
594632 # set using a tuple of index arrays
595- expanded_array_native = expanded_array_native .at [slim_to_native_tuple ].set (
633+ image_both_native = image_both_native .at [slim_to_native_tuple ].set (
596634 jnp .asarray (image .array )
597635 )
598- expanded_array_native = expanded_array_native .at [
636+ image_both_native = image_both_native .at [
599637 slim_to_native_blurring_tuple
600638 ].set (jnp .asarray (blurring_image .array ))
601639
602- kernel = self .stored_native .array
640+ # FFT the combined image
641+ fft_image_native = jnp .fft .rfft2 (image_both_native , s = self .fft_shape , axes = (0 , 1 ))
603642
604- convolve_native = jax .scipy .signal .convolve (
605- expanded_array_native , kernel , mode = "same" , method = jax_method
643+ # Multiply by PSF in Fourier space and invert
644+ blurred_image_full = jnp .fft .irfft2 (self .fft_psf * fft_image_native , s = self .fft_shape , axes = (0 , 1 ))
645+
646+ # Crop back to mask_shape
647+ start_indices = tuple ((full_size - out_size ) // 2 for full_size , out_size in zip (self .full_shape , self .mask_shape ))
648+ out_shape_full = self .mask_shape
649+ blurred_image_native = jax .lax .dynamic_slice (blurred_image_full , start_indices , out_shape_full )
650+
651+ return Array2D (values = blurred_image_native [slim_to_native_tuple ], mask = image .mask )
652+
653+ def convolve_mapping_matrix (
654+ self ,
655+ mapping_matrix ,
656+ mask ,
657+ blurring_mapping_matrix = None ,
658+ jax_method = "direct" ,
659+ ):
660+ """
661+ Convolve a source-pixel mapping matrix with this PSF in Fourier space.
662+ Also supports a blurring mapping matrix, which is added in the same way as blurring_image.
663+
664+ Parameters
665+ ----------
666+ mapping_matrix : (N_masked_pixels, N_src)
667+ Mapping matrix of unmasked pixels to source pixels.
668+ mask : Mask
669+ Mask object with slim-to-native mapping.
670+ blurring_mapping_matrix : (N_blurring_pixels, N_src) or None
671+ Mapping matrix for the blurring grid (outside the mask core).
672+ If provided, this is scattered into native space and added to the main mapping matrix.
673+ jax_method : str
674+ Currently unused, placeholder for different convolution backends.
675+
676+ Returns
677+ -------
678+ (N_masked_pixels, N_src)
679+ Blurred mapping matrix in slim form (only unmasked pixels).
680+ """
681+
682+ slim_to_native_tuple = self .slim_to_native_tuple
683+ if slim_to_native_tuple is None :
684+ slim_to_native_tuple = jnp .nonzero (
685+ jnp .logical_not (mask .array ), size = mask .shape [0 ]
686+ )
687+
688+ n_src = mapping_matrix .shape [1 ]
689+
690+ # allocate full native + source dimension
691+ mapping_matrix_native = jnp .zeros (
692+ mask .shape + (n_src ,), dtype = mapping_matrix .dtype
606693 )
607694
608- convolved_array_1d = convolve_native [slim_to_native_tuple ]
695+ # scatter main mapping matrix
696+ mapping_matrix_native = mapping_matrix_native .at [slim_to_native_tuple ].set (
697+ mapping_matrix
698+ )
609699
610- return Array2D (values = convolved_array_1d , mask = image .mask )
700+ # optionally scatter blurring mapping matrix
701+ if blurring_mapping_matrix is not None :
702+ slim_to_native_blurring_tuple = self .slim_to_native_blurring_tuple
703+ mapping_matrix_native = mapping_matrix_native .at [
704+ slim_to_native_blurring_tuple
705+ ].set (blurring_mapping_matrix )
706+
707+ # FFT convolution
708+ fft_mapping_matrix_native = jnp .fft .rfft2 (
709+ mapping_matrix_native , s = self .fft_shape , axes = (0 , 1 )
710+ )
711+ blurred_mapping_matrix_full = jnp .fft .irfft2 (
712+ self .fft_psf_mapping * fft_mapping_matrix_native , s = self .fft_shape , axes = (0 , 1 )
713+ )
714+
715+ # crop back
716+ start_indices = tuple (
717+ (full_size - out_size ) // 2
718+ for full_size , out_size in zip (self .full_shape , self .mask_shape )
719+ ) + (0 ,)
720+ out_shape_full = self .mask_shape + (blurred_mapping_matrix_full .shape [2 ],)
721+ blurred_mapping_matrix_native = jax .lax .dynamic_slice (
722+ blurred_mapping_matrix_full , start_indices , out_shape_full
723+ )
724+
725+ # return slim form
726+ return blurred_mapping_matrix_native [slim_to_native_tuple ]
611727
612728 def convolve_image_no_blurring (self , image , mask , jax_method = "direct" ):
613729 """
@@ -657,7 +773,7 @@ def convolve_image_no_blurring(self, image, mask, jax_method="direct"):
657773
658774 return Array2D (values = convolved_array_1d , mask = mask )
659775
660- def convolve_image_no_blurring_for_mapping (self , image , mask , jax_method = "direct" ):
776+ def convolve_image_no_blurring_for_mapping_via_real_space (self , image , mask , jax_method = "direct" ):
661777 """
662778 For a given 1D array and blurring array, convolve the two using this psf.
663779
@@ -700,7 +816,62 @@ def convolve_image_no_blurring_for_mapping(self, image, mask, jax_method="direct
700816
701817 return Array2D (values = convolved_array_1d , mask = mask )
702818
703- def convolve_mapping_matrix (self , mapping_matrix , mask , jax_method = "direct" ):
819+ def convolve_image_via_real_space (self , image , blurring_image , jax_method = "direct" ):
820+ """
821+ For a given 1D array and blurring array, convolve the two using this psf.
822+
823+ Parameters
824+ ----------
825+ image
826+ 1D array of the values which are to be blurred with the psf's PSF.
827+ blurring_image
828+ 1D array of the blurring values which blur into the array after PSF convolution.
829+ jax_method
830+ If JAX is enabled this keyword will indicate what method is used for the PSF
831+ convolution. Can be either `direct` to calculate it in real space or `fft`
832+ to calculated it via a fast Fourier transform. `fft` is typically faster for
833+ kernels that are more than about 5x5. Default is `fft`.
834+ """
835+
836+ slim_to_native_tuple = self .slim_to_native_tuple
837+ slim_to_native_blurring_tuple = self .slim_to_native_blurring_tuple
838+
839+ if slim_to_native_tuple is None :
840+
841+ slim_to_native_tuple = jnp .nonzero (
842+ jnp .logical_not (image .mask .array ), size = image .shape [0 ]
843+ )
844+
845+ if slim_to_native_blurring_tuple is None :
846+
847+ slim_to_native_blurring_tuple = jnp .nonzero (
848+ jnp .logical_not (blurring_image .mask .array ), size = blurring_image .shape [0 ]
849+ )
850+
851+ # make sure dtype matches what you want
852+ expanded_array_native = jnp .zeros (
853+ image .mask .shape , dtype = jnp .asarray (image .array ).dtype
854+ )
855+
856+ # set using a tuple of index arrays
857+ expanded_array_native = expanded_array_native .at [slim_to_native_tuple ].set (
858+ jnp .asarray (image .array )
859+ )
860+ expanded_array_native = expanded_array_native .at [
861+ slim_to_native_blurring_tuple
862+ ].set (jnp .asarray (blurring_image .array ))
863+
864+ kernel = self .stored_native .array
865+
866+ convolve_native = jax .scipy .signal .convolve (
867+ expanded_array_native , kernel , mode = "same" , method = jax_method
868+ )
869+
870+ convolved_array_1d = convolve_native [slim_to_native_tuple ]
871+
872+ return Array2D (values = convolved_array_1d , mask = image .mask )
873+
874+ def convolve_mapping_matrix_via_real_space (self , mapping_matrix , mask , jax_method = "direct" ):
704875 """For a given 1D array and blurring array, convolve the two using this psf.
705876
706877 Parameters
@@ -709,5 +880,5 @@ def convolve_mapping_matrix(self, mapping_matrix, mask, jax_method="direct"):
709880 1D array of the values which are to be blurred with the psf's PSF.
710881 """
711882 return jax .vmap (
712- self .convolve_image_no_blurring_for_mapping , in_axes = (1 , None , None )
883+ self .convolve_image_no_blurring_for_mapping_via_real_space , in_axes = (1 , None , None )
713884 )(mapping_matrix , mask , jax_method ).T
0 commit comments