@@ -30,40 +30,77 @@ def __init__(
3030 mask : Mask2D ,
3131 ):
3232 """
33- Compute the padded shapes required for FFT-based convolution with this kernel.
33+ Compute and store the padded shapes and masks required for FFT-based convolution
34+ of masked 2D data with a kernel.
3435
35- FFT convolution requires the input image and kernel to be zero-padded so that
36- the convolution is equivalent to linear convolution (not circular) and to avoid
37- wrap-around artefacts.
36+ FFT convolution operates on fully-sampled rectangular arrays, whereas scientific
37+ imaging data are typically defined only on a subset of pixels via a mask. This
38+ class determines how masked real-space data are embedded into a padded array,
39+ transformed to Fourier space, convolved with a kernel, and transformed back such
40+ that the result is equivalent to linear (not circular) convolution.
3841
39- This method inspects the mask and the kernel shape to determines three key shapes:
42+ The input mask defines which pixels contain valid data and therefore which
43+ regions of the image must be retained when mapping to and from FFT space. The
44+ kernel shape defines how far flux from unmasked pixels can spread into masked
45+ regions during convolution.
4046
41- - ``mask_shape``: the rectangular bounding-box region of the mask that encloses
42- all unmasked (False) pixels, padded by half the kernel size in each direction.
43- This is the minimal region that must be retained for convolution. This is
44- not used or computed outside thi sfunction with the two shapes below
45- used instead.
47+ This initializer inspects the mask and kernel to compute three key array shapes:
4648
47- - ``full_shape``: the "linear convolution shape", equal to
48- ``mask_shape + kernel_shape - 1``. This is the minimal padded size required
49- for an exact linear convolution. This is also not used or computed outside
50- this function.
49+ ``mask_shape``
50+ The minimal rectangular bounding box enclosing all unmasked (False) pixels
51+ in the mask, expanded by half the kernel size in each direction. This is the
52+ smallest region that must be retained to ensure that convolution does not
53+ lose flux near the mask boundary.
5154
52- - ``fft_shape``: the FFT-efficient padded shape, obtained by rounding each
53- dimension of ``full_shape`` up to the next fast length for real FFTs
54- (via ``scipy.fft.next_fast_len``). Using this ensures efficient FFT execution.
55+ ``full_shape``
56+ The minimal array shape required for exact linear convolution, defined as::
57+
58+ full_shape = mask_shape + kernel_shape - 1
59+
60+ Padding to this size guarantees that FFT-based convolution is mathematically
61+ equivalent to direct spatial convolution, with no wrap-around artefacts.
62+
63+ ``fft_shape``
64+ The FFT-efficient padded shape actually used for computation. Each dimension
65+ of ``full_shape`` is independently rounded up to the next fast length for
66+ real FFTs using ``scipy.fft.next_fast_len``. This shape defines the size of
67+ all arrays sent to and returned from FFT space.
68+
69+ Note that even FFT sizes are currently incremented to odd sizes as a
70+ workaround for kernel-centering issues with even-sized kernels. This is an
71+ implementation detail and should be replaced by correct internal padding
72+ and centering logic.
73+
74+ After determining ``fft_shape``, the input mask is padded accordingly and a
75+ *blurring mask* is derived. The blurring mask identifies pixels that are outside
76+ the original unmasked region but receive non-zero flux due to convolution with
77+ the kernel. These pixels must be retained when mapping results back to the
78+ masked domain to ensure correct convolution near mask boundaries.
5579
5680 Parameters
5781 ----------
82+ kernel
83+ The 2D convolution kernel (e.g. a PSF). If a 1D kernel is provided, it is
84+ internally promoted to a minimal 2D kernel.
5885 mask
59- A 2D mask where False indicates unmasked pixels (valid data) and True
60- indicates masked pixels. The bounding-box of the False region is used
61- to compute the convolution region .
86+ A 2D boolean mask where False values indicate unmasked (valid) pixels and
87+ True values indicate masked pixels. The spatial extent of False pixels
88+ defines the region of the image that is embedded into FFT space .
6289
63- Returns
64- -------
90+ Attributes
91+ ----------
6592 fft_shape
66- The FFT-friendly padded shape for efficient convolution.
93+ The FFT-friendly padded shape used for all Fourier transforms.
94+ mask
95+ The input mask padded to ``fft_shape``, with masked pixels set to True.
96+ blurring_mask
97+ A derived mask identifying pixels that are masked in the original input
98+ but receive flux due to convolution with the kernel.
99+ fft_kernel
100+ The real FFT of the padded kernel, used for efficient convolution in
101+ Fourier space.
102+ fft_kernel_mapping
103+ A broadcast-ready view of ``fft_kernel`` for multi-channel convolution.
67104 """
68105 if len (kernel ) == 1 :
69106 kernel = kernel .resized_from (new_shape = (3 , 3 ))
@@ -111,53 +148,66 @@ def __init__(
111148 ** kwargs ,
112149 ):
113150 """
114- A 2D convolution kernel stored as an array of values paired to a uniform 2D mask.
151+ A 2D convolution kernel paired with a mask, providing real-space and FFT-based
152+ convolution of images or mapping matrices.
115153
116- The ``Convolver`` is a subclass of ``Array2D`` with additional methods for performing
117- point spread function (PSF) convolution of images or mapping matrices . Each entry of
118- the kernel corresponds to a PSF value at the centre of a pixel in the unmasked grid.
154+ The ``Convolver`` is a subclass of ``Array2D`` with additional methods for
155+ performing point spread function (PSF) convolution. Each entry of the kernel
156+ corresponds to the PSF value at the centre of a pixel on a uniform 2D grid.
119157
120158 Two convolution modes are supported:
121159
122- - **Real-space convolution**: performed directly via sliding-window summation or
123- ``jax.scipy.signal.convolve``. This is exact but can be slow for large kernels.
124- - **FFT convolution**: performed by transforming both the kernel and the input image
125- into Fourier space, multiplying, and transforming back. This is typically faster
126- for kernels larger than ~5×5, but requires careful zero-padding.
160+ - **Real-space convolution**:
161+ Performed directly via sliding-window summation or
162+ ``jax.scipy.signal.convolve``. This mode is exact and requires no padding,
163+ but becomes computationally expensive for large kernels.
127164
128- When using FFT convolution, the input image and mask are automatically padded such
129- that the FFT avoids circular wrap-around artefacts. This padding is computed from the
130- kernel size via :meth:`fft_shape_from`. The padded shape is stored in ``fft_shape``.
131- If FFT convolution is attempted without precomputing and applying this padding,
132- an exception is raised to avoid silent shape mismatches.
165+ - **FFT-based convolution**:
166+ Performed by embedding the input image and kernel into padded arrays,
167+ transforming them to Fourier space, multiplying, and transforming back.
168+ This mode is typically faster for kernels larger than approximately 5×5,
169+ but requires careful handling of padding, masking, and kernel centering.
170+
171+ All logic related to FFT padding, mask expansion, linear (non-circular)
172+ convolution, and blurring-mask construction is handled by
173+ ``ConvolverState``. See the ``ConvolverState`` docstring for a detailed
174+ description of how masked real-space data are mapped to and from FFT space.
175+
176+ When FFT convolution is enabled, the ``Convolver`` expects a corresponding
177+ ``ConvolverState`` defining the FFT geometry. The padded FFT shape is stored
178+ in ``state.fft_shape`` and must be consistent with the shape of any arrays
179+ passed for convolution. Attempting FFT convolution without a valid state
180+ will raise an exception to avoid silent shape or alignment errors.
133181
134182 Parameters
135183 ----------
136184 kernel
137- The raw 2D kernel values. Can be normalised to sum to unity if ``normalize=True``.
138- mask
139- The 2D mask associated with the kernel, defining the pixels each kernel value is
140- paired with.
141- header
142- Optional metadata (e.g. FITS header) associated with the kernel .
185+ The raw 2D kernel values. These represent the PSF sampled at pixel
186+ centres and may be normalised to sum to unity if ``normalize=True``.
187+ state
188+ Optional ``ConvolverState`` instance defining FFT padding, mask
189+ expansion, and kernel Fourier transforms. Required when using FFT
190+ convolution .
143191 normalize
144- If True, the kernel values are rescaled such that they sum to 1.0 .
192+ If True, the kernel values are rescaled such that their sum is unity .
145193 use_fft
146- If True, convolution is performed in Fourier space with zero-padding.
194+ If True, convolution is performed in Fourier space using the provided
195+ ``ConvolverState``.
147196 If False, convolution is performed in real space.
148- If None, the config file default is used.
197+ If None, the default behaviour specified in the configuration is used.
149198 *args, **kwargs
150199 Passed to the ``Array2D`` constructor.
151200
152201 Notes
153202 -----
154- - FFT padding can be disabled globally with ``disable_fft_pad=True`` when
155- constructing ``Imaging`` objects, in which case convolution will either
156- use real space or proceed without padding.
203+ - When performing real-space convolution, the kernel must have odd dimensions
204+ in both axes so that it has a well-defined central pixel.
205+ - When performing FFT convolution, kernel centering, padding, and mask
206+ expansion are handled by ``ConvolverState``.
157207 - Blurring masks ensure that PSF flux spilling outside the main image mask
158208 is included correctly. Omitting them may lead to underestimated PSF wings.
159- - For unit tests with tiny kernels, FFT and real-space convolution may differ
160- slightly due to edge and truncation effects.
209+ - For very small kernels, FFT and real-space convolution may differ slightly
210+ near mask boundaries due to padding and truncation effects.
161211 """
162212 self .kernel = kernel
163213
0 commit comments