33from pathlib import Path
44from typing import Optional , Union
55
6- from autoconf import cached_property
7-
86from autoarray .dataset .abstract .dataset import AbstractDataset
97from autoarray .dataset .grids import GridsDataset
10- from autoarray .dataset .imaging .w_tilde import WTildeImaging
8+ from autoarray .inversion .inversion .imaging .inversion_imaging_util import (
9+ ImagingSparseOperator ,
10+ )
1111from autoarray .structures .arrays .uniform_2d import Array2D
1212from autoarray .structures .arrays .kernel_2d import Kernel2D
1313from autoarray .mask .mask_2d import Mask2D
1414from autoarray import type as ty
1515
1616from autoarray import exc
1717from autoarray .operators .over_sampling import over_sample_util
18- from autoarray .inversion .inversion .imaging import inversion_imaging_numba_util
18+
19+ from autoarray .inversion .inversion .imaging import inversion_imaging_util
1920
2021logger = logging .getLogger (__name__ )
2122
@@ -32,7 +33,7 @@ def __init__(
3233 disable_fft_pad : bool = True ,
3334 use_normalized_psf : Optional [bool ] = True ,
3435 check_noise_map : bool = True ,
35- w_tilde : Optional [WTildeImaging ] = None ,
36+ sparse_operator : Optional [ImagingSparseOperator ] = None ,
3637 ):
3738 """
3839 An imaging dataset, containing the image data, noise-map, PSF and associated quantities
@@ -86,9 +87,9 @@ def __init__(
8687 the PSF kernel does not change the overall normalization of the image when it is convolved with it.
8788 check_noise_map
8889 If True, the noise-map is checked to ensure all values are above zero.
89- w_tilde
90- The w_tilde formalism of the linear algebra equations precomputes the convolution of every pair of masked
91- noise-map values given the PSF (see `inversion.inversion_util`). Pass the `WTildeImaging ` object here to
90+ sparse_operator
91+ The sparse linear algebra formalism of the linear algebra equations precomputes the convolution of every pair of masked
92+ noise-map values given the PSF (see `inversion.inversion_util`). Pass the `ImagingSparseOperator ` object here to
9293 enable this linear algebra formalism for pixelized reconstructions.
9394 """
9495
@@ -191,17 +192,17 @@ def __init__(
191192 if psf .mask .shape [0 ] % 2 == 0 or psf .mask .shape [1 ] % 2 == 0 :
192193 raise exc .KernelException ("Kernel2D Kernel2D must be odd" )
193194
194- use_w_tilde = True if w_tilde is not None else False
195+ use_sparse_operator = True if sparse_operator is not None else False
195196
196197 self .grids = GridsDataset (
197198 mask = self .data .mask ,
198199 over_sample_size_lp = self .over_sample_size_lp ,
199200 over_sample_size_pixelization = self .over_sample_size_pixelization ,
200201 psf = self .psf ,
201- use_w_tilde = use_w_tilde ,
202+ use_sparse_operator = use_sparse_operator ,
202203 )
203204
204- self .w_tilde = w_tilde
205+ self .sparse_operator = sparse_operator
205206
206207 @classmethod
207208 def from_fits (
@@ -474,9 +475,13 @@ def apply_over_sampling(
474475
475476 return dataset
476477
477- def apply_w_tilde (self , disable_fft_pad : bool = False ):
478+ def apply_sparse_operator (
479+ self ,
480+ batch_size : int = 128 ,
481+ disable_fft_pad : bool = False ,
482+ ):
478483 """
479- The w_tilde formalism of the linear algebra equations precomputes the convolution of every pair of masked
484+ The sparse linear algebra formalism precomputes the convolution of every pair of masked
480485 noise-map values given the PSF (see `inversion.inversion_util`).
481486
482487 The `WTilde` object stores these precomputed values in the imaging dataset ensuring they are only computed once
@@ -487,12 +492,66 @@ def apply_w_tilde(self, disable_fft_pad: bool = False):
487492
488493 Returns
489494 -------
490- WTildeImaging
491- Precomputed values used for the w tilde formalism of linear algebra calculations.
495+ batch_size
496+ The size of batches used to compute the w-tilde curvature matrix via FFT-based convolution,
497+ which can be reduced to produce lower memory usage at the cost of speed
498+ disable_fft_pad
499+ The FFT PSF convolution is optimal for a certain 2D FFT padding or trimming,
500+ which places the fewest zeros around the image. If this is set to `True`, this optimal padding is not
501+ performed and the image is used as-is. This is normally used to avoid repadding data that has already been
502+ padded.
503+ use_jax
504+ Whether to use JAX to compute W-Tilde. This requires JAX to be installed.
492505 """
493506
494- logger .info ("IMAGING - Computing W-Tilde... May take a moment." )
507+ sparse_operator = (
508+ inversion_imaging_util .ImagingSparseOperator .from_noise_map_and_psf (
509+ data = self .data ,
510+ noise_map = self .noise_map ,
511+ psf = self .psf .native ,
512+ batch_size = batch_size ,
513+ )
514+ )
515+
516+ return Imaging (
517+ data = self .data ,
518+ noise_map = self .noise_map ,
519+ psf = self .psf ,
520+ noise_covariance_matrix = self .noise_covariance_matrix ,
521+ over_sample_size_lp = self .over_sample_size_lp ,
522+ over_sample_size_pixelization = self .over_sample_size_pixelization ,
523+ disable_fft_pad = disable_fft_pad ,
524+ check_noise_map = False ,
525+ sparse_operator = sparse_operator ,
526+ )
495527
528+ def apply_sparse_operator_cpu (
529+ self ,
530+ disable_fft_pad : bool = False ,
531+ ):
532+ """
533+ The sparse linear algebra formalism precomputes the convolution of every pair of masked
534+ noise-map values given the PSF (see `inversion.inversion_util`).
535+
536+ The `WTilde` object stores these precomputed values in the imaging dataset ensuring they are only computed once
537+ per analysis.
538+
539+ This uses lazy allocation such that the calculation is only performed when the wtilde matrices are used,
540+ ensuring efficient set up of the `Imaging` class.
541+
542+ Returns
543+ -------
544+ batch_size
545+ The size of batches used to compute the w-tilde curvature matrix via FFT-based convolution,
546+ which can be reduced to produce lower memory usage at the cost of speed
547+ disable_fft_pad
548+ The FFT PSF convolution is optimal for a certain 2D FFT padding or trimming,
549+ which places the fewest zeros around the image. If this is set to `True`, this optimal padding is not
550+ performed and the image is used as-is. This is normally used to avoid repadding data that has already been
551+ padded.
552+ use_jax
553+ Whether to use JAX to compute W-Tilde. This requires JAX to be installed.
554+ """
496555 try :
497556 import numba
498557 except ModuleNotFoundError :
@@ -504,20 +563,24 @@ def apply_w_tilde(self, disable_fft_pad: bool = False):
504563 "https://pyautolens.readthedocs.io/en/latest/installation/overview.html"
505564 )
506565
566+ from autoarray .inversion .inversion .imaging_numba import (
567+ inversion_imaging_numba_util ,
568+ )
569+
507570 (
508- curvature_preload ,
571+ psf_precision_operator_sparse ,
509572 indexes ,
510573 lengths ,
511- ) = inversion_imaging_numba_util .w_tilde_curvature_preload_imaging_from (
574+ ) = inversion_imaging_numba_util .psf_precision_operator_sparse_from (
512575 noise_map_native = np .array (self .noise_map .native .array ).astype ("float64" ),
513576 kernel_native = np .array (self .psf .native .array ).astype ("float64" ),
514577 native_index_for_slim_index = np .array (
515578 self .mask .derive_indexes .native_for_slim
516579 ).astype ("int" ),
517580 )
518581
519- w_tilde = WTildeImaging (
520- curvature_preload = curvature_preload ,
582+ sparse_operator = inversion_imaging_numba_util . SparseLinAlgImagingNumba (
583+ psf_precision_operator_sparse = psf_precision_operator_sparse ,
521584 indexes = indexes .astype ("int" ),
522585 lengths = lengths .astype ("int" ),
523586 noise_map = self .noise_map ,
@@ -534,7 +597,7 @@ def apply_w_tilde(self, disable_fft_pad: bool = False):
534597 over_sample_size_pixelization = self .over_sample_size_pixelization ,
535598 disable_fft_pad = disable_fft_pad ,
536599 check_noise_map = False ,
537- w_tilde = w_tilde ,
600+ sparse_operator = sparse_operator ,
538601 )
539602
540603 def output_to_fits (
0 commit comments