@@ -65,6 +65,9 @@ def __init__(
6565 fft_shape
6666 The FFT-friendly padded shape for efficient convolution.
6767 """
68+ if len (kernel ) == 1 :
69+ kernel = kernel .resized_from (new_shape = (3 , 3 ))
70+
6871 self .kernel = kernel
6972
7073 ys , xs = np .where (~ mask )
@@ -83,6 +86,10 @@ def __init__(
8386 )
8487 fft_shape = tuple (scipy .fft .next_fast_len (s , real = True ) for s in full_shape )
8588
89+ # make fft_shape odd x odd to avoid wrap-around artefacts with even kernels
90+ # TODO : Fix this so it pads corrrectly internally
91+ fft_shape = tuple (s + 1 if s % 2 == 0 else s for s in fft_shape )
92+
8693 self .fft_shape = fft_shape
8794 self .mask = mask .resized_from (self .fft_shape , pad_value = 1 )
8895 self .blurring_mask = self .mask .derive_mask .blurring_from (
@@ -159,19 +166,25 @@ def __init__(
159166 self .kernel ._array , np .sum (self .kernel ._array )
160167 )
161168
162- if not use_fft :
169+ self ._use_fft = use_fft
170+
171+ if not self ._use_fft :
163172 if (
164173 self .kernel .shape_native [0 ] % 2 == 0
165174 or self .kernel .shape_native [1 ] % 2 == 0
166175 ):
167176 raise exc .KernelException ("Convolver Convolver must be odd" )
168177
169- self ._use_fft = use_fft
170-
171178 self ._state = state
172179
173180 def state_from (self , mask ):
174181
182+ if (
183+ mask .shape_native [0 ] != self .kernel .shape_native [0 ]
184+ or mask .shape_native [1 ] != self .kernel .shape_native [1 ]
185+ ):
186+ return ConvolverState (kernel = self .kernel , mask = mask )
187+
175188 if self ._state is None :
176189 return ConvolverState (kernel = self .kernel , mask = mask )
177190
@@ -184,6 +197,13 @@ def use_fft(self):
184197
185198 return self ._use_fft
186199
200+ @property
201+ def normalized (self ) -> "Kernel2D" :
202+ """
203+ Normalize the Kernel2D such that its data_vector values sum to unity.
204+ """
205+ return Convolver (kernel = self .kernel , state = self ._state , normalize = True )
206+
187207 @classmethod
188208 def no_blur (cls , pixel_scales ):
189209 """
@@ -836,7 +856,7 @@ def convolved_image_via_real_space_np_from(
836856 state = self .state_from (mask = image .mask )
837857
838858 # start with native array padded with zeros
839- image_native = xp .zeros (state .mask . shape , dtype = image . array . dtype )
859+ image_native = xp .zeros (state .fft_shape )
840860
841861 # set image pixels
842862 image_native [state .mask .slim_to_native_tuple ] = image .array
0 commit comments