@@ -555,7 +555,59 @@ class InterferometerSparseLinAlg:
555555 batch_size : int
556556 w_dtype : "jax.numpy.dtype"
557557 Khat : "jax.Array" # (2y, 2x), complex
558-
558+ """
559+ Cached FFT operator state for fast interferometer curvature-matrix assembly.
560+
561+ This class packages *static* quantities needed to apply the interferometer
562+ W~ operator efficiently using FFTs, so that repeated likelihood evaluations
563+ do not redo expensive precomputation.
564+
565+ Conceptually, the interferometer W~ operator is a translationally-invariant
566+ linear operator on a rectangular real-space grid, constructed from the
567+ `nufft_precision_operator` (a 2D array of correlation values on pixel offsets).
568+ By taking an FFT of this preload, the operator can be applied to batches of
569+ images via elementwise multiplication in Fourier space:
570+
571+ apply_W(F) = IFFT( FFT(F_pad) * Khat )
572+
573+ where `F_pad` is a (2y, 2x) padded version of `F` and `Khat = FFT(nufft_precision_operator)`.
574+
575+ The curvature matrix for a pixelization (mapper) is then assembled from sparse
576+ mapping triplets without forming dense mapping matrices:
577+
578+ C = A^T W A
579+
580+ where A is the sparse mapping from source pixels to image pixels.
581+
582+ Caching / validity
583+ ------------------
584+ Instances are safe to cache and reuse as long as all of the following remain fixed:
585+
586+ - `nufft_precision_operator` (hence `Khat`)
587+ - the definition of the rectangular FFT grid (y_shape, x_shape)
588+ - dtype / precision (float32 vs float64)
589+ - `batch_size`
590+
591+ Parameters stored
592+ -----------------
593+ dirty_image
594+ Convenience field for associated dirty image data (not used directly in
595+ curvature assembly in this method). Stored as a NumPy array to match
596+ upstream interfaces.
597+ y_shape, x_shape
598+ Shape of the *rectangular* real-space grid (not the masked slim grid).
599+ M
600+ Number of rectangular pixels, M = y_shape * x_shape.
601+ batch_size
602+ Number of source-pixel columns assembled and operated on per block.
603+ Larger batch sizes improve throughput on GPU but increase memory usage.
604+ w_dtype
605+ Floating-point dtype for weights and accumulations (e.g. float64).
606+ Khat
607+ FFT of the curvature preload, shape (2y_shape, 2x_shape), complex.
608+ This is the frequency-domain representation of the W~ operator kernel.
609+ """
610+
559611 @classmethod
560612 def from_nufft_precision_operator (
561613 self ,
@@ -564,6 +616,41 @@ def from_nufft_precision_operator(
564616 * ,
565617 batch_size : int = 128 ,
566618 ):
619+ """
620+ Construct an `InterferometerSparseLinAlg` from a curvature-preload array.
621+
622+ This is the standard factory used in interferometer inversions.
623+
624+ The curvature preload is assumed to be defined on a (2y, 2x) rectangular
625+ grid of pixel offsets, where y and x correspond to the *unmasked extent*
626+ of the real-space grid. The preload is FFT'd once to obtain `Khat`, which
627+ is then reused for every subsequent curvature matrix build.
628+
629+ Parameters
630+ ----------
631+ nufft_precision_operator
632+ Real-valued array of shape (2y, 2x) encoding the W~ operator in real
633+ space as a function of pixel offsets. The shape must be even in both
634+ axes so that y_shape = H2//2 and x_shape = W2//2 are integers.
635+ dirty_image
636+ The dirty image associated with the dataset (or any convenient
637+ reference image). Not required for curvature computation itself,
638+ but commonly stored alongside the state for debugging / plotting.
639+ batch_size
640+ Number of source-pixel columns processed per block when assembling
641+ the curvature matrix. Higher values typically improve GPU efficiency
642+ but increase intermediate memory usage.
643+
644+ Returns
645+ -------
646+ InterferometerSparseLinAlg
647+ Immutable cached state object containing shapes and FFT kernel `Khat`.
648+
649+ Raises
650+ ------
651+ ValueError
652+ If `nufft_precision_operator` does not have even shape in both dimensions.
653+ """
567654 import jax .numpy as jnp
568655
569656 H2 , W2 = nufft_precision_operator .shape
@@ -596,12 +683,67 @@ def curvature_matrix_via_sparse_operator_from(
596683 fft_index_for_masked_pixel : np .ndarray ,
597684 ):
598685 """
599- Compute curvature matrix for an interferometer inversion using a precomputed FFT state.
600-
601- IMPORTANT
602- ---------
603- - COO construction is unchanged from the known-working implementation
604- - Only FFT- and geometry-related quantities are taken from `fft_state`
686+ Assemble the curvature matrix C = Aᵀ W A using sparse triplets and the FFT W~ operator.
687+
688+ This method computes the mapper (pixelization) curvature matrix without
689+ forming a dense mapping matrix. Instead, it uses fixed-length mapping
690+ arrays (pixel indexes + weights per masked pixel) which define a sparse
691+ mapping operator A in COO-like form.
692+
693+ Algorithm outline
694+ -----------------
695+ Let S be the number of source pixels and M be the number of rectangular
696+ real-space pixels.
697+
698+ 1) Build a fixed-length COO stream from the mapping arrays:
699+ rows_rect[k] : rectangular pixel index (0..M-1)
700+ cols[k] : source pixel index (0..S-1)
701+ vals[k] : mapping weight
702+ Invalid mappings (cols < 0 or cols >= S) are masked out.
703+
704+ 2) Process source-pixel columns in blocks of width `batch_size`:
705+ - Scatter the block’s source columns into a dense (M, batch_size) array F.
706+ - Apply the W~ operator by FFT:
707+ G = apply_W(F)
708+ - Project back with Aᵀ via segmented reductions:
709+ C[:, start:start+B] = Aᵀ G
710+
711+ 3) Symmetrize the result:
712+ C <- 0.5 * (C + Cᵀ)
713+
714+ Parameters
715+ ----------
716+ pix_indexes_for_sub_slim_index
717+ Integer array of shape (M_masked, Pmax).
718+ For each masked (slim) image pixel, stores the source-pixel indices
719+ involved in the interpolation / mapping stencil. Invalid entries
720+ should be set to -1.
721+ pix_weights_for_sub_slim_index
722+ Floating array of shape (M_masked, Pmax).
723+ Weights corresponding to `pix_indexes_for_sub_slim_index`.
724+ These should already include any oversampling normalisation (e.g.
725+ sub-pixel fractions) required by the mapper.
726+ pix_pixels
727+ Number of source pixels, S.
728+ fft_index_for_masked_pixel
729+ Integer array of shape (M_masked,).
730+ Maps each masked (slim) image pixel index to its corresponding
731+ rectangular-grid flat index (0..M-1). This embeds the masked pixel
732+ ordering into the FFT-friendly rectangular grid.
733+
734+ Returns
735+ -------
736+ jax.Array
737+ Curvature matrix of shape (S, S), symmetric.
738+
739+ Notes
740+ -----
741+ - The inner computation is written in JAX and is intended to be jitted.
742+ For best performance, keep `batch_size` fixed (static) across calls.
743+ - Choosing `batch_size` as a divisor of S avoids a smaller tail block,
744+ but correctness does not require that if the implementation masks the tail.
745+ - This method uses FFTs on padded (2y, 2x) arrays; memory use scales with
746+ batch_size and grid size.
605747 """
606748
607749 import jax .numpy as jnp
0 commit comments