Skip to content

Commit 9cfa1f3

Browse files
Jammy2211Jammy2211
authored andcommitted
fixc quirks in convolver state
1 parent b6f12b1 commit 9cfa1f3

File tree

1 file changed

+24
-4
lines changed

1 file changed

+24
-4
lines changed

autoarray/operators/convolver.py

Lines changed: 24 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,9 @@ def __init__(
6565
fft_shape
6666
The FFT-friendly padded shape for efficient convolution.
6767
"""
68+
if len(kernel) == 1:
69+
kernel = kernel.resized_from(new_shape=(3, 3))
70+
6871
self.kernel = kernel
6972

7073
ys, xs = np.where(~mask)
@@ -83,6 +86,10 @@ def __init__(
8386
)
8487
fft_shape = tuple(scipy.fft.next_fast_len(s, real=True) for s in full_shape)
8588

89+
# make fft_shape odd x odd to avoid wrap-around artefacts with even kernels
90+
# TODO : Fix this so it pads corrrectly internally
91+
fft_shape = tuple(s + 1 if s % 2 == 0 else s for s in fft_shape)
92+
8693
self.fft_shape = fft_shape
8794
self.mask = mask.resized_from(self.fft_shape, pad_value=1)
8895
self.blurring_mask = self.mask.derive_mask.blurring_from(
@@ -159,19 +166,25 @@ def __init__(
159166
self.kernel._array, np.sum(self.kernel._array)
160167
)
161168

162-
if not use_fft:
169+
self._use_fft = use_fft
170+
171+
if not self._use_fft:
163172
if (
164173
self.kernel.shape_native[0] % 2 == 0
165174
or self.kernel.shape_native[1] % 2 == 0
166175
):
167176
raise exc.KernelException("Convolver Convolver must be odd")
168177

169-
self._use_fft = use_fft
170-
171178
self._state = state
172179

173180
def state_from(self, mask):
174181

182+
if (
183+
mask.shape_native[0] != self.kernel.shape_native[0]
184+
or mask.shape_native[1] != self.kernel.shape_native[1]
185+
):
186+
return ConvolverState(kernel=self.kernel, mask=mask)
187+
175188
if self._state is None:
176189
return ConvolverState(kernel=self.kernel, mask=mask)
177190

@@ -184,6 +197,13 @@ def use_fft(self):
184197

185198
return self._use_fft
186199

200+
@property
201+
def normalized(self) -> "Kernel2D":
202+
"""
203+
Normalize the Kernel2D such that its data_vector values sum to unity.
204+
"""
205+
return Convolver(kernel=self.kernel, state=self._state, normalize=True)
206+
187207
@classmethod
188208
def no_blur(cls, pixel_scales):
189209
"""
@@ -836,7 +856,7 @@ def convolved_image_via_real_space_np_from(
836856
state = self.state_from(mask=image.mask)
837857

838858
# start with native array padded with zeros
839-
image_native = xp.zeros(state.mask.shape, dtype=image.array.dtype)
859+
image_native = xp.zeros(state.fft_shape)
840860

841861
# set image pixels
842862
image_native[state.mask.slim_to_native_tuple] = image.array

0 commit comments

Comments
 (0)