Skip to content

Commit ad89496

Browse files
Jammy2211Jammy2211
authored andcommitted
InterferometerSparseLinAlg -> InterferometerSparseOperator
1 parent 01524d0 commit ad89496

File tree

2 files changed

+150
-8
lines changed

2 files changed

+150
-8
lines changed

autoarray/inversion/inversion/interferometer/inversion_interferometer_util.py

Lines changed: 149 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -555,7 +555,59 @@ class InterferometerSparseLinAlg:
555555
batch_size: int
556556
w_dtype: "jax.numpy.dtype"
557557
Khat: "jax.Array" # (2y, 2x), complex
558-
558+
"""
559+
Cached FFT operator state for fast interferometer curvature-matrix assembly.
560+
561+
This class packages *static* quantities needed to apply the interferometer
562+
W~ operator efficiently using FFTs, so that repeated likelihood evaluations
563+
do not redo expensive precomputation.
564+
565+
Conceptually, the interferometer W~ operator is a translationally-invariant
566+
linear operator on a rectangular real-space grid, constructed from the
567+
`nufft_precision_operator` (a 2D array of correlation values on pixel offsets).
568+
By taking an FFT of this preload, the operator can be applied to batches of
569+
images via elementwise multiplication in Fourier space:
570+
571+
apply_W(F) = IFFT( FFT(F_pad) * Khat )
572+
573+
where `F_pad` is a (2y, 2x) padded version of `F` and `Khat = FFT(nufft_precision_operator)`.
574+
575+
The curvature matrix for a pixelization (mapper) is then assembled from sparse
576+
mapping triplets without forming dense mapping matrices:
577+
578+
C = A^T W A
579+
580+
where A is the sparse mapping from source pixels to image pixels.
581+
582+
Caching / validity
583+
------------------
584+
Instances are safe to cache and reuse as long as all of the following remain fixed:
585+
586+
- `nufft_precision_operator` (hence `Khat`)
587+
- the definition of the rectangular FFT grid (y_shape, x_shape)
588+
- dtype / precision (float32 vs float64)
589+
- `batch_size`
590+
591+
Parameters stored
592+
-----------------
593+
dirty_image
594+
Convenience field for associated dirty image data (not used directly in
595+
curvature assembly in this method). Stored as a NumPy array to match
596+
upstream interfaces.
597+
y_shape, x_shape
598+
Shape of the *rectangular* real-space grid (not the masked slim grid).
599+
M
600+
Number of rectangular pixels, M = y_shape * x_shape.
601+
batch_size
602+
Number of source-pixel columns assembled and operated on per block.
603+
Larger batch sizes improve throughput on GPU but increase memory usage.
604+
w_dtype
605+
Floating-point dtype for weights and accumulations (e.g. float64).
606+
Khat
607+
FFT of the curvature preload, shape (2y_shape, 2x_shape), complex.
608+
This is the frequency-domain representation of the W~ operator kernel.
609+
"""
610+
559611
@classmethod
560612
def from_nufft_precision_operator(
561613
self,
@@ -564,6 +616,41 @@ def from_nufft_precision_operator(
564616
*,
565617
batch_size: int = 128,
566618
):
619+
"""
620+
Construct an `InterferometerSparseLinAlg` from a curvature-preload array.
621+
622+
This is the standard factory used in interferometer inversions.
623+
624+
The curvature preload is assumed to be defined on a (2y, 2x) rectangular
625+
grid of pixel offsets, where y and x correspond to the *unmasked extent*
626+
of the real-space grid. The preload is FFT'd once to obtain `Khat`, which
627+
is then reused for every subsequent curvature matrix build.
628+
629+
Parameters
630+
----------
631+
nufft_precision_operator
632+
Real-valued array of shape (2y, 2x) encoding the W~ operator in real
633+
space as a function of pixel offsets. The shape must be even in both
634+
axes so that y_shape = H2//2 and x_shape = W2//2 are integers.
635+
dirty_image
636+
The dirty image associated with the dataset (or any convenient
637+
reference image). Not required for curvature computation itself,
638+
but commonly stored alongside the state for debugging / plotting.
639+
batch_size
640+
Number of source-pixel columns processed per block when assembling
641+
the curvature matrix. Higher values typically improve GPU efficiency
642+
but increase intermediate memory usage.
643+
644+
Returns
645+
-------
646+
InterferometerSparseLinAlg
647+
Immutable cached state object containing shapes and FFT kernel `Khat`.
648+
649+
Raises
650+
------
651+
ValueError
652+
If `nufft_precision_operator` does not have even shape in both dimensions.
653+
"""
567654
import jax.numpy as jnp
568655

569656
H2, W2 = nufft_precision_operator.shape
@@ -596,12 +683,67 @@ def curvature_matrix_via_sparse_operator_from(
596683
fft_index_for_masked_pixel: np.ndarray,
597684
):
598685
"""
599-
Compute curvature matrix for an interferometer inversion using a precomputed FFT state.
600-
601-
IMPORTANT
602-
---------
603-
- COO construction is unchanged from the known-working implementation
604-
- Only FFT- and geometry-related quantities are taken from `fft_state`
686+
Assemble the curvature matrix C = Aᵀ W A using sparse triplets and the FFT W~ operator.
687+
688+
This method computes the mapper (pixelization) curvature matrix without
689+
forming a dense mapping matrix. Instead, it uses fixed-length mapping
690+
arrays (pixel indexes + weights per masked pixel) which define a sparse
691+
mapping operator A in COO-like form.
692+
693+
Algorithm outline
694+
-----------------
695+
Let S be the number of source pixels and M be the number of rectangular
696+
real-space pixels.
697+
698+
1) Build a fixed-length COO stream from the mapping arrays:
699+
rows_rect[k] : rectangular pixel index (0..M-1)
700+
cols[k] : source pixel index (0..S-1)
701+
vals[k] : mapping weight
702+
Invalid mappings (cols < 0 or cols >= S) are masked out.
703+
704+
2) Process source-pixel columns in blocks of width `batch_size`:
705+
- Scatter the block’s source columns into a dense (M, batch_size) array F.
706+
- Apply the W~ operator by FFT:
707+
G = apply_W(F)
708+
- Project back with Aᵀ via segmented reductions:
709+
C[:, start:start+B] = Aᵀ G
710+
711+
3) Symmetrize the result:
712+
C <- 0.5 * (C + Cᵀ)
713+
714+
Parameters
715+
----------
716+
pix_indexes_for_sub_slim_index
717+
Integer array of shape (M_masked, Pmax).
718+
For each masked (slim) image pixel, stores the source-pixel indices
719+
involved in the interpolation / mapping stencil. Invalid entries
720+
should be set to -1.
721+
pix_weights_for_sub_slim_index
722+
Floating array of shape (M_masked, Pmax).
723+
Weights corresponding to `pix_indexes_for_sub_slim_index`.
724+
These should already include any oversampling normalisation (e.g.
725+
sub-pixel fractions) required by the mapper.
726+
pix_pixels
727+
Number of source pixels, S.
728+
fft_index_for_masked_pixel
729+
Integer array of shape (M_masked,).
730+
Maps each masked (slim) image pixel index to its corresponding
731+
rectangular-grid flat index (0..M-1). This embeds the masked pixel
732+
ordering into the FFT-friendly rectangular grid.
733+
734+
Returns
735+
-------
736+
jax.Array
737+
Curvature matrix of shape (S, S), symmetric.
738+
739+
Notes
740+
-----
741+
- The inner computation is written in JAX and is intended to be jitted.
742+
For best performance, keep `batch_size` fixed (static) across calls.
743+
- Choosing `batch_size` as a divisor of S avoids a smaller tail block,
744+
but correctness does not require that if the implementation masks the tail.
745+
- This method uses FFTs on padded (2y, 2x) arrays; memory use scales with
746+
batch_size and grid size.
605747
"""
606748

607749
import jax.numpy as jnp

test_autoarray/inversion/inversion/interferometer/test_inversion_interferometer_util.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -120,7 +120,7 @@ def test__curvature_matrix_via_psf_precision_operator_from():
120120

121121
pix_weights_for_sub_slim_index = np.ones(shape=(9, 1))
122122

123-
sparse_operator = aa.InterferometerSparseLinAlg.from_nufft_precision_operator(
123+
sparse_operator = aa.InterferometerSparseOperator.from_nufft_precision_operator(
124124
nufft_precision_operator=nufft_precision_operator,
125125
dirty_image=None,
126126
)

0 commit comments

Comments
 (0)