@@ -122,25 +122,6 @@ def __init__(
122122
123123 self .stored_native = self .native
124124
125- self .slim_to_native_tuple = None
126-
127- if image_mask is not None :
128-
129- slim_to_native = image_mask .derive_indexes .native_for_slim .astype ("int32" )
130- self .slim_to_native_tuple = (slim_to_native [:, 0 ], slim_to_native [:, 1 ])
131-
132- self .slim_to_native_blurring_tuple = None
133-
134- if blurring_mask is not None :
135-
136- slim_to_native_blurring = (
137- blurring_mask .derive_indexes .native_for_slim .astype ("int32" )
138- )
139- self .slim_to_native_blurring_tuple = (
140- slim_to_native_blurring [:, 0 ],
141- slim_to_native_blurring [:, 1 ],
142- )
143-
144125 self .fft_shape = fft_shape
145126
146127 self .mask_shape = None
@@ -585,18 +566,6 @@ def mapping_matrix_native_from(
585566 Contains contributions from both the main mapping matrix and, if provided,
586567 the blurring mapping matrix.
587568 """
588- slim_to_native_tuple = self .slim_to_native_tuple
589- if slim_to_native_tuple is None :
590- mask_flat = xp .logical_not (mask .array )
591-
592- if xp .__name__ .startswith ("jax" ):
593- slim_to_native_tuple = xp .nonzero (
594- mask_flat , size = mapping_matrix .shape [0 ]
595- )
596- else :
597- slim_to_native = mask .derive_indexes .native_for_slim .astype ("int32" )
598- slim_to_native_tuple = (slim_to_native [:, 0 ], slim_to_native [:, 1 ])
599-
600569 n_src = mapping_matrix .shape [1 ]
601570
602571 # Allocate full native grid (ny, nx, n_src)
@@ -606,43 +575,22 @@ def mapping_matrix_native_from(
606575
607576 # Scatter main mapping matrix into native cube
608577 if xp .__name__ .startswith ("jax" ):
609- mapping_matrix_native = mapping_matrix_native .at [slim_to_native_tuple ].set (
578+ mapping_matrix_native = mapping_matrix_native .at [mask . slim_to_native_tuple ].set (
610579 mapping_matrix
611580 )
612581 else :
613- mapping_matrix_native [slim_to_native_tuple ] = mapping_matrix
582+ mapping_matrix_native [mask .slim_to_native_tuple ] = mapping_matrix
583+
614584 # Optionally scatter blurring mapping matrix
585+
615586 if blurring_mapping_matrix is not None :
616- slim_to_native_blurring_tuple = self .slim_to_native_blurring_tuple
617-
618- if slim_to_native_blurring_tuple is None :
619- if blurring_mask is None :
620- raise ValueError (
621- "blurring_mask must be provided if blurring_mapping_matrix is given "
622- "and slim_to_native_blurring_tuple is None."
623- )
624-
625- if xp .__name__ .startswith ("jax" ):
626- mask_flat = xp .logical_not (blurring_mask .array )
627- slim_to_native_blurring_tuple = xp .nonzero (
628- mask_flat ,
629- size = blurring_mapping_matrix .shape [0 ],
630- )
631- else :
632- slim_to_native_blurring = (
633- blurring_mask .derive_indexes .native_for_slim .astype ("int32" )
634- )
635- slim_to_native_blurring_tuple = (
636- slim_to_native_blurring [:, 0 ],
637- slim_to_native_blurring [:, 1 ],
638- )
639587
640588 if xp .__name__ .startswith ("jax" ):
641589 mapping_matrix_native = mapping_matrix_native .at [
642- slim_to_native_blurring_tuple
590+ blurring_mask . slim_to_native_tuple
643591 ].set (blurring_mapping_matrix )
644592 else :
645- mapping_matrix_native [slim_to_native_blurring_tuple ] = (
593+ mapping_matrix_native [blurring_mask . slim_to_native_tuple ] = (
646594 blurring_mapping_matrix
647595 )
648596
@@ -722,31 +670,17 @@ def convolved_image_from(self, image, blurring_image, jax_method="direct", xp=np
722670 mask_shape = self .mask_shape
723671 fft_psf = self .fft_psf
724672
725- slim_to_native_tuple = self .slim_to_native_tuple
726- slim_to_native_blurring_tuple = self .slim_to_native_blurring_tuple
727-
728- if slim_to_native_tuple is None :
729-
730- mask_flat = xp .logical_not (image .mask .array )
731- slim_to_native_tuple = xp .nonzero (mask_flat , size = image .shape [0 ])
732-
733673 # start with native image padded with zeros
734674 image_both_native = xp .zeros (image .mask .shape , dtype = image .dtype )
735675
736- image_both_native = image_both_native .at [slim_to_native_tuple ].set (
676+ image_both_native = image_both_native .at [image . mask . slim_to_native_tuple ].set (
737677 xp .asarray (image .array )
738678 )
739679
740680 # add blurring contribution if provided
741681 if blurring_image is not None :
742- if slim_to_native_blurring_tuple is None :
743-
744- mask_flat = xp .logical_not (blurring_image .mask .array )
745- slim_to_native_blurring_tuple = xp .nonzero (
746- mask_flat , size = blurring_image .shape [0 ]
747- )
748682
749- image_both_native = image_both_native .at [slim_to_native_blurring_tuple ].set (
683+ image_both_native = image_both_native .at [blurring_image . mask . slim_to_native_tuple ].set (
750684 xp .asarray (blurring_image .array )
751685 )
752686
@@ -777,7 +711,7 @@ def convolved_image_from(self, image, blurring_image, jax_method="direct", xp=np
777711 )
778712
779713 blurred_image = Array2D (
780- values = blurred_image_native [slim_to_native_tuple ], mask = image .mask
714+ values = blurred_image_native [image . mask . slim_to_native_tuple ], mask = image .mask
781715 )
782716
783717 if self .fft_shape is None :
@@ -880,13 +814,6 @@ def convolved_mapping_matrix_from(
880814 mask_shape = self .mask_shape
881815 fft_psf_mapping = self .fft_psf_mapping
882816
883- slim_to_native_tuple = self .slim_to_native_tuple
884-
885- if slim_to_native_tuple is None :
886-
887- mask_flat = xp .logical_not (mask .array )
888- slim_to_native_tuple = xp .nonzero (mask_flat , size = mapping_matrix .shape [0 ])
889-
890817 mapping_matrix_native = self .mapping_matrix_native_from (
891818 mapping_matrix = mapping_matrix ,
892819 mask = mask ,
@@ -1055,30 +982,18 @@ def convolved_image_via_real_space_from(
1055982
1056983 import jax
1057984
1058- slim_to_native_tuple = self .slim_to_native_tuple
1059- slim_to_native_blurring_tuple = self .slim_to_native_blurring_tuple
1060-
1061- if slim_to_native_tuple is None :
1062- mask_flat = xp .logical_not (image .mask .array )
1063- slim_to_native_tuple = xp .nonzero (mask_flat , size = image .shape [0 ])
1064-
1065985 # start with native array padded with zeros
1066986 image_native = xp .zeros (image .mask .shape , dtype = xp .asarray (image .array ).dtype )
1067987
1068988 # set image pixels
1069- image_native = image_native .at [slim_to_native_tuple ].set (
989+ image_native = image_native .at [image . mask . slim_to_native_tuple ].set (
1070990 xp .asarray (image .array )
1071991 )
1072992
1073993 # add blurring contribution if provided
1074994 if blurring_image is not None :
1075- if slim_to_native_blurring_tuple is None :
1076995
1077- slim_to_native_blurring_tuple = xp .nonzero (
1078- mask_flat ,
1079- size = blurring_image .shape [0 ],
1080- )
1081- image_native = image_native .at [slim_to_native_blurring_tuple ].set (
996+ image_native = image_native .at [blurring_image .mask .slim_to_native_tuple ].set (
1082997 xp .asarray (blurring_image .array )
1083998 )
1084999 else :
@@ -1094,7 +1009,7 @@ def convolved_image_via_real_space_from(
10941009 image_native , kernel , mode = "same" , method = jax_method
10951010 )
10961011
1097- convolved_array_1d = convolve_native [slim_to_native_tuple ]
1012+ convolved_array_1d = convolve_native [image . mask . slim_to_native_tuple ]
10981013
10991014 return Array2D (values = convolved_array_1d , mask = image .mask )
11001015
@@ -1146,16 +1061,6 @@ def convolved_mapping_matrix_via_real_space_from(
11461061
11471062 import jax
11481063
1149- slim_to_native_tuple = self .slim_to_native_tuple
1150-
1151- if slim_to_native_tuple is None :
1152-
1153- mask_flat = xp .logical_not (mask .array )
1154- slim_to_native_tuple = xp .nonzero (
1155- mask_flat ,
1156- size = mapping_matrix .shape [0 ],
1157- )
1158-
11591064 mapping_matrix_native = self .mapping_matrix_native_from (
11601065 mapping_matrix = mapping_matrix ,
11611066 mask = mask ,
@@ -1174,7 +1079,7 @@ def convolved_mapping_matrix_via_real_space_from(
11741079 )
11751080
11761081 # return slim form
1177- return blurred_mapping_matrix_native [slim_to_native_tuple ]
1082+ return blurred_mapping_matrix_native [mask . slim_to_native_tuple ]
11781083
11791084 def convolved_image_via_real_space_np_from (
11801085 self , image : np .ndarray , blurring_image : Optional [np .ndarray ] = None , xp = np
@@ -1207,32 +1112,16 @@ def convolved_image_via_real_space_np_from(
12071112
12081113 from scipy .signal import convolve as scipy_convolve
12091114
1210- slim_to_native_tuple = self .slim_to_native_tuple
1211- slim_to_native_blurring_tuple = self .slim_to_native_blurring_tuple
1212-
1213- if slim_to_native_tuple is None :
1214- slim_to_native = image .mask .derive_indexes .native_for_slim .astype ("int32" )
1215- slim_to_native_tuple = (slim_to_native [:, 0 ], slim_to_native [:, 1 ])
1216-
12171115 # start with native array padded with zeros
12181116 image_native = xp .zeros (image .mask .shape , dtype = xp .asarray (image .array ).dtype )
12191117
12201118 # set image pixels
1221- image_native [slim_to_native_tuple ] = xp .asarray (image .array )
1119+ image_native [image . mask . slim_to_native_tuple ] = xp .asarray (image .array )
12221120
12231121 # add blurring contribution if provided
12241122 if blurring_image is not None :
1225- if slim_to_native_blurring_tuple is None :
1226-
1227- slim_to_native_blurring = (
1228- blurring_image .mask .derive_indexes .native_for_slim .astype ("int32" )
1229- )
1230- slim_to_native_blurring_tuple = (
1231- slim_to_native_blurring [:, 0 ],
1232- slim_to_native_blurring [:, 1 ],
1233- )
12341123
1235- image_native [slim_to_native_blurring_tuple ] = xp .asarray (
1124+ image_native [blurring_image . mask . slim_to_native_tuple ] = xp .asarray (
12361125 blurring_image .array
12371126 )
12381127 else :
@@ -1248,7 +1137,7 @@ def convolved_image_via_real_space_np_from(
12481137 image_native , kernel , mode = "same" , method = "auto"
12491138 )
12501139
1251- convolved_array_1d = convolve_native [slim_to_native_tuple ]
1140+ convolved_array_1d = convolve_native [image . mask . slim_to_native_tuple ]
12521141
12531142 return Array2D (values = convolved_array_1d , mask = image .mask )
12541143
@@ -1290,13 +1179,6 @@ def convolved_mapping_matrix_via_real_space_np_from(
12901179
12911180 from scipy .signal import convolve as scipy_convolve
12921181
1293- slim_to_native_tuple = self .slim_to_native_tuple
1294-
1295- if slim_to_native_tuple is None :
1296-
1297- slim_to_native = mask .derive_indexes .native_for_slim .astype ("int32" )
1298- slim_to_native_tuple = (slim_to_native [:, 0 ], slim_to_native [:, 1 ])
1299-
13001182 mapping_matrix_native = self .mapping_matrix_native_from (
13011183 mapping_matrix = mapping_matrix ,
13021184 mask = mask ,
@@ -1314,4 +1196,4 @@ def convolved_mapping_matrix_via_real_space_np_from(
13141196 )
13151197
13161198 # return slim form
1317- return blurred_mapping_matrix_native [slim_to_native_tuple ]
1199+ return blurred_mapping_matrix_native [mask . slim_to_native_tuple ]
0 commit comments