@@ -24,6 +24,8 @@ def __init__(
2424 header = None ,
2525 normalize : bool = False ,
2626 store_native : bool = False ,
27+ image_mask = None ,
28+ blurring_mask = None ,
2729 * args ,
2830 ** kwargs ,
2931 ):
@@ -56,6 +58,25 @@ def __init__(
5658
5759 self .stored_native = self .native
5860
61+ self .slim_to_native_tuple = None
62+
63+ if image_mask is not None :
64+
65+ slim_to_native = image_mask .derive_indexes .native_for_slim .astype ("int32" )
66+ self .slim_to_native_tuple = (slim_to_native [:, 0 ], slim_to_native [:, 1 ])
67+
68+ self .slim_to_native_blurring_tuple = None
69+
70+ if blurring_mask is not None :
71+
72+ slim_to_native_blurring = (
73+ blurring_mask .derive_indexes .native_for_slim .astype ("int32" )
74+ )
75+ self .slim_to_native_blurring_tuple = (
76+ slim_to_native_blurring [:, 0 ],
77+ slim_to_native_blurring [:, 1 ],
78+ )
79+
5980 @classmethod
6081 def no_mask (
6182 cls ,
@@ -64,6 +85,8 @@ def no_mask(
6485 shape_native : Tuple [int , int ] = None ,
6586 origin : Tuple [float , float ] = (0.0 , 0.0 ),
6687 normalize : bool = False ,
88+ image_mask = None ,
89+ blurring_mask = None ,
6790 ):
6891 """
6992 Create a Kernel2D (see *Kernel2D.__new__*) by inputting the kernel values in 1D or 2D, automatically
@@ -91,7 +114,13 @@ def no_mask(
91114 pixel_scales = pixel_scales ,
92115 origin = origin ,
93116 )
94- return Kernel2D (values = values , mask = values .mask , normalize = normalize )
117+ return Kernel2D (
118+ values = values ,
119+ mask = values .mask ,
120+ normalize = normalize ,
121+ image_mask = image_mask ,
122+ blurring_mask = blurring_mask ,
123+ )
95124
96125 @classmethod
97126 def full (
@@ -540,29 +569,41 @@ def convolve_image(self, image, blurring_image, jax_method="direct"):
540569 kernels that are more than about 5x5. Default is `fft`.
541570 """
542571
543- slim_to_native = jnp .nonzero (
544- jnp .logical_not (image .mask .array ), size = image .shape [0 ]
545- )
546- slim_to_native_blurring = jnp .nonzero (
547- jnp .logical_not (blurring_image .mask .array ), size = blurring_image .shape [0 ]
548- )
572+ slim_to_native_tuple = self .slim_to_native_tuple
573+ slim_to_native_blurring_tuple = self .slim_to_native_blurring_tuple
549574
550- expanded_array_native = jnp . zeros ( image . mask . shape )
575+ if slim_to_native_tuple is None :
551576
552- expanded_array_native = expanded_array_native .at [slim_to_native ].set (
553- image .array
577+ slim_to_native_tuple = jnp .nonzero (
578+ jnp .logical_not (image .mask .array ), size = image .shape [0 ]
579+ )
580+
581+ if slim_to_native_blurring_tuple is None :
582+
583+ slim_to_native_blurring_tuple = jnp .nonzero (
584+ jnp .logical_not (blurring_image .mask .array ), size = blurring_image .shape [0 ]
585+ )
586+
587+ # make sure dtype matches what you want
588+ expanded_array_native = jnp .zeros (
589+ image .mask .shape , dtype = jnp .asarray (image .array ).dtype
554590 )
555- expanded_array_native = expanded_array_native .at [slim_to_native_blurring ].set (
556- blurring_image .array
591+
592+ # set using a tuple of index arrays
593+ expanded_array_native = expanded_array_native .at [slim_to_native_tuple ].set (
594+ jnp .asarray (image .array )
557595 )
596+ expanded_array_native = expanded_array_native .at [
597+ slim_to_native_blurring_tuple
598+ ].set (jnp .asarray (blurring_image .array ))
558599
559600 kernel = self .stored_native .array
560601
561602 convolve_native = jax .scipy .signal .convolve (
562603 expanded_array_native , kernel , mode = "same" , method = jax_method
563604 )
564605
565- convolved_array_1d = convolve_native [slim_to_native ]
606+ convolved_array_1d = convolve_native [slim_to_native_tuple ]
566607
567608 return Array2D (values = convolved_array_1d , mask = image .mask )
568609
@@ -583,24 +624,77 @@ def convolve_image_no_blurring(self, image, mask, jax_method="direct"):
583624 kernels that are more than about 5x5. Default is `fft`.
584625 """
585626
586- slim_to_native = jnp .nonzero (jnp .logical_not (mask .array ), size = image .shape [0 ])
627+ slim_to_native_tuple = self .slim_to_native_tuple
628+
629+ if slim_to_native_tuple is None :
630+
631+ slim_to_native_tuple = jnp .nonzero (
632+ jnp .logical_not (mask .array ), size = image .shape [0 ]
633+ )
587634
635+ # make sure dtype matches what you want
588636 expanded_array_native = jnp .zeros (mask .shape )
589637
638+ # set using a tuple of index arrays
590639 if isinstance (image , np .ndarray ) or isinstance (image , jnp .ndarray ):
591- expanded_array_native = expanded_array_native .at [slim_to_native ].set (image )
640+ expanded_array_native = expanded_array_native .at [slim_to_native_tuple ].set (
641+ image
642+ )
592643 else :
593- expanded_array_native = expanded_array_native .at [slim_to_native ].set (
594- image .array
644+ expanded_array_native = expanded_array_native .at [slim_to_native_tuple ].set (
645+ jnp .asarray (image .array )
646+ )
647+
648+ kernel = self .stored_native .array
649+
650+ convolve_native = jax .scipy .signal .convolve (
651+ expanded_array_native , kernel , mode = "same" , method = jax_method
652+ )
653+
654+ convolved_array_1d = convolve_native [slim_to_native_tuple ]
655+
656+ return Array2D (values = convolved_array_1d , mask = mask )
657+
658+ def convolve_image_no_blurring_for_mapping (self , image , mask , jax_method = "direct" ):
659+ """
660+ For a given 1D array and blurring array, convolve the two using this psf.
661+
662+ Parameters
663+ ----------
664+ image
665+ 1D array of the values which are to be blurred with the psf's PSF.
666+ blurring_image
667+ 1D array of the blurring values which blur into the array after PSF convolution.
668+ jax_method
669+ If JAX is enabled this keyword will indicate what method is used for the PSF
670+ convolution. Can be either `direct` to calculate it in real space or `fft`
671+ to calculated it via a fast Fourier transform. `fft` is typically faster for
672+ kernels that are more than about 5x5. Default is `fft`.
673+ """
674+
675+ slim_to_native_tuple = self .slim_to_native_tuple
676+
677+ if slim_to_native_tuple is None :
678+
679+ slim_to_native_tuple = jnp .nonzero (
680+ jnp .logical_not (mask .array ), size = image .shape [0 ]
595681 )
596682
683+ # make sure dtype matches what you want
684+ expanded_array_native = jnp .zeros (mask .shape )
685+
686+ # set using a tuple of index arrays
687+ expanded_array_native = expanded_array_native .at [slim_to_native_tuple ].set (
688+ image
689+ )
690+
597691 kernel = self .stored_native .array
598692
599693 convolve_native = jax .scipy .signal .convolve (
600694 expanded_array_native , kernel , mode = "same" , method = jax_method
601695 )
602696
603- convolved_array_1d = convolve_native [slim_to_native ]
697+ convolved_array_1d = convolve_native [slim_to_native_tuple ]
604698
605699 return Array2D (values = convolved_array_1d , mask = mask )
606700
@@ -612,6 +706,6 @@ def convolve_mapping_matrix(self, mapping_matrix, mask, jax_method="direct"):
612706 image
613707 1D array of the values which are to be blurred with the psf's PSF.
614708 """
615- return jax .vmap (self . convolve_image_no_blurring , in_axes = ( 1 , None , None ))(
616- mapping_matrix , mask , jax_method
617- ).T
709+ return jax .vmap (
710+ self . convolve_image_no_blurring_for_mapping , in_axes = ( 1 , None , None )
711+ )( mapping_matrix , mask , jax_method ) .T
0 commit comments