Skip to content

Commit 7d258ce

Browse files
Jammy2211Jammy2211
authored andcommitted
merge
2 parents 115257c + a655bb9 commit 7d258ce

37 files changed

+2455
-1622
lines changed

.github/workflows/main.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,7 @@ jobs:
6666
- name: Run tests
6767
run: |
6868
export ROOT_DIR=`pwd`
69+
export JAX_ENABLE_X64=True
6970
export PYTHONPATH=$PYTHONPATH:$ROOT_DIR/PyAutoConf
7071
export PYTHONPATH=$PYTHONPATH:$ROOT_DIR/PyAutoArray
7172
pushd PyAutoArray

autoarray/__init__.py

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -9,17 +9,14 @@
99
from . import util
1010
from . import fixtures
1111
from . import mock as m
12-
from .dataset.interferometer.w_tilde import load_curvature_preload_if_compatible
12+
1313
from .dataset import preprocess
1414
from .dataset.abstract.dataset import AbstractDataset
15-
from .dataset.abstract.w_tilde import AbstractWTilde
1615
from .dataset.grids import GridsInterface
1716
from .dataset.imaging.dataset import Imaging
1817
from .dataset.imaging.simulator import SimulatorImaging
19-
from .dataset.imaging.w_tilde import WTildeImaging
2018
from .dataset.interferometer.dataset import Interferometer
2119
from .dataset.interferometer.simulator import SimulatorInterferometer
22-
from .dataset.interferometer.w_tilde import WTildeInterferometer
2320
from .dataset.dataset_model import DatasetModel
2421
from .fit.fit_dataset import AbstractFit
2522
from .fit.fit_dataset import FitDataset
@@ -46,9 +43,15 @@
4643
from .inversion.pixelization.image_mesh.abstract import AbstractImageMesh
4744
from .inversion.pixelization.mesh.abstract import AbstractMesh
4845
from .inversion.inversion.imaging.mapping import InversionImagingMapping
49-
from .inversion.inversion.imaging.w_tilde import InversionImagingWTilde
50-
from .inversion.inversion.interferometer.w_tilde import InversionInterferometerWTilde
46+
from .inversion.inversion.imaging.sparse import InversionImagingSparse
47+
from .inversion.inversion.imaging.inversion_imaging_util import ImagingSparseOperator
48+
from .inversion.inversion.interferometer.sparse import (
49+
InversionInterferometerSparse,
50+
)
5151
from .inversion.inversion.interferometer.mapping import InversionInterferometerMapping
52+
from .inversion.inversion.interferometer.inversion_interferometer_util import (
53+
InterferometerSparseOperator,
54+
)
5255
from .inversion.linear_obj.linear_obj import LinearObj
5356
from .inversion.linear_obj.func_list import AbstractLinearObjFuncList
5457
from .mask.derive.indexes_2d import DeriveIndexes2D

autoarray/dataset/abstract/w_tilde.py

Lines changed: 0 additions & 20 deletions
This file was deleted.

autoarray/dataset/grids.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ def __init__(
1717
over_sample_size_lp: Union[int, Array2D],
1818
over_sample_size_pixelization: Union[int, Array2D],
1919
psf: Optional[Kernel2D] = None,
20-
use_w_tilde: bool = False,
20+
use_sparse_operator: bool = False,
2121
):
2222
"""
2323
Contains grids of (y,x) Cartesian coordinates at the centre of every pixel in the dataset's image and
@@ -66,7 +66,7 @@ def __init__(
6666
self._blurring = None
6767
self._border_relocator = None
6868

69-
self.use_w_tilde = use_w_tilde
69+
self.use_sparse_operator = use_sparse_operator
7070

7171
@property
7272
def lp(self):
@@ -120,7 +120,7 @@ def border_relocator(self) -> BorderRelocator:
120120
self._border_relocator = BorderRelocator(
121121
mask=self.mask,
122122
sub_size=self.over_sample_size_pixelization,
123-
use_w_tilde=self.use_w_tilde,
123+
use_sparse_operator=self.use_sparse_operator,
124124
)
125125

126126
return self._border_relocator

autoarray/dataset/imaging/dataset.py

Lines changed: 84 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -3,19 +3,20 @@
33
from pathlib import Path
44
from typing import Optional, Union
55

6-
from autoconf import cached_property
7-
86
from autoarray.dataset.abstract.dataset import AbstractDataset
97
from autoarray.dataset.grids import GridsDataset
10-
from autoarray.dataset.imaging.w_tilde import WTildeImaging
8+
from autoarray.inversion.inversion.imaging.inversion_imaging_util import (
9+
ImagingSparseOperator,
10+
)
1111
from autoarray.structures.arrays.uniform_2d import Array2D
1212
from autoarray.structures.arrays.kernel_2d import Kernel2D
1313
from autoarray.mask.mask_2d import Mask2D
1414
from autoarray import type as ty
1515

1616
from autoarray import exc
1717
from autoarray.operators.over_sampling import over_sample_util
18-
from autoarray.inversion.inversion.imaging import inversion_imaging_numba_util
18+
19+
from autoarray.inversion.inversion.imaging import inversion_imaging_util
1920

2021
logger = logging.getLogger(__name__)
2122

@@ -32,7 +33,7 @@ def __init__(
3233
disable_fft_pad: bool = True,
3334
use_normalized_psf: Optional[bool] = True,
3435
check_noise_map: bool = True,
35-
w_tilde: Optional[WTildeImaging] = None,
36+
sparse_operator: Optional[ImagingSparseOperator] = None,
3637
):
3738
"""
3839
An imaging dataset, containing the image data, noise-map, PSF and associated quantities
@@ -86,9 +87,9 @@ def __init__(
8687
the PSF kernel does not change the overall normalization of the image when it is convolved with it.
8788
check_noise_map
8889
If True, the noise-map is checked to ensure all values are above zero.
89-
w_tilde
90-
The w_tilde formalism of the linear algebra equations precomputes the convolution of every pair of masked
91-
noise-map values given the PSF (see `inversion.inversion_util`). Pass the `WTildeImaging` object here to
90+
sparse_operator
91+
The sparse linear algebra formalism of the linear algebra equations precomputes the convolution of every pair of masked
92+
noise-map values given the PSF (see `inversion.inversion_util`). Pass the `ImagingSparseOperator` object here to
9293
enable this linear algebra formalism for pixelized reconstructions.
9394
"""
9495

@@ -191,17 +192,17 @@ def __init__(
191192
if psf.mask.shape[0] % 2 == 0 or psf.mask.shape[1] % 2 == 0:
192193
raise exc.KernelException("Kernel2D Kernel2D must be odd")
193194

194-
use_w_tilde = True if w_tilde is not None else False
195+
use_sparse_operator = True if sparse_operator is not None else False
195196

196197
self.grids = GridsDataset(
197198
mask=self.data.mask,
198199
over_sample_size_lp=self.over_sample_size_lp,
199200
over_sample_size_pixelization=self.over_sample_size_pixelization,
200201
psf=self.psf,
201-
use_w_tilde=use_w_tilde,
202+
use_sparse_operator=use_sparse_operator,
202203
)
203204

204-
self.w_tilde = w_tilde
205+
self.sparse_operator = sparse_operator
205206

206207
@classmethod
207208
def from_fits(
@@ -474,9 +475,13 @@ def apply_over_sampling(
474475

475476
return dataset
476477

477-
def apply_w_tilde(self, disable_fft_pad: bool = False):
478+
def apply_sparse_operator(
479+
self,
480+
batch_size: int = 128,
481+
disable_fft_pad: bool = False,
482+
):
478483
"""
479-
The w_tilde formalism of the linear algebra equations precomputes the convolution of every pair of masked
484+
The sparse linear algebra formalism precomputes the convolution of every pair of masked
480485
noise-map values given the PSF (see `inversion.inversion_util`).
481486
482487
The `WTilde` object stores these precomputed values in the imaging dataset ensuring they are only computed once
@@ -487,12 +492,66 @@ def apply_w_tilde(self, disable_fft_pad: bool = False):
487492
488493
Returns
489494
-------
490-
WTildeImaging
491-
Precomputed values used for the w tilde formalism of linear algebra calculations.
495+
batch_size
496+
The size of batches used to compute the w-tilde curvature matrix via FFT-based convolution,
497+
which can be reduced to produce lower memory usage at the cost of speed
498+
disable_fft_pad
499+
The FFT PSF convolution is optimal for a certain 2D FFT padding or trimming,
500+
which places the fewest zeros around the image. If this is set to `True`, this optimal padding is not
501+
performed and the image is used as-is. This is normally used to avoid repadding data that has already been
502+
padded.
503+
use_jax
504+
Whether to use JAX to compute W-Tilde. This requires JAX to be installed.
492505
"""
493506

494-
logger.info("IMAGING - Computing W-Tilde... May take a moment.")
507+
sparse_operator = (
508+
inversion_imaging_util.ImagingSparseOperator.from_noise_map_and_psf(
509+
data=self.data,
510+
noise_map=self.noise_map,
511+
psf=self.psf.native,
512+
batch_size=batch_size,
513+
)
514+
)
515+
516+
return Imaging(
517+
data=self.data,
518+
noise_map=self.noise_map,
519+
psf=self.psf,
520+
noise_covariance_matrix=self.noise_covariance_matrix,
521+
over_sample_size_lp=self.over_sample_size_lp,
522+
over_sample_size_pixelization=self.over_sample_size_pixelization,
523+
disable_fft_pad=disable_fft_pad,
524+
check_noise_map=False,
525+
sparse_operator=sparse_operator,
526+
)
495527

528+
def apply_sparse_operator_cpu(
529+
self,
530+
disable_fft_pad: bool = False,
531+
):
532+
"""
533+
The sparse linear algebra formalism precomputes the convolution of every pair of masked
534+
noise-map values given the PSF (see `inversion.inversion_util`).
535+
536+
The `WTilde` object stores these precomputed values in the imaging dataset ensuring they are only computed once
537+
per analysis.
538+
539+
This uses lazy allocation such that the calculation is only performed when the wtilde matrices are used,
540+
ensuring efficient set up of the `Imaging` class.
541+
542+
Returns
543+
-------
544+
batch_size
545+
The size of batches used to compute the w-tilde curvature matrix via FFT-based convolution,
546+
which can be reduced to produce lower memory usage at the cost of speed
547+
disable_fft_pad
548+
The FFT PSF convolution is optimal for a certain 2D FFT padding or trimming,
549+
which places the fewest zeros around the image. If this is set to `True`, this optimal padding is not
550+
performed and the image is used as-is. This is normally used to avoid repadding data that has already been
551+
padded.
552+
use_jax
553+
Whether to use JAX to compute W-Tilde. This requires JAX to be installed.
554+
"""
496555
try:
497556
import numba
498557
except ModuleNotFoundError:
@@ -504,20 +563,24 @@ def apply_w_tilde(self, disable_fft_pad: bool = False):
504563
"https://pyautolens.readthedocs.io/en/latest/installation/overview.html"
505564
)
506565

566+
from autoarray.inversion.inversion.imaging_numba import (
567+
inversion_imaging_numba_util,
568+
)
569+
507570
(
508-
curvature_preload,
571+
psf_precision_operator_sparse,
509572
indexes,
510573
lengths,
511-
) = inversion_imaging_numba_util.w_tilde_curvature_preload_imaging_from(
574+
) = inversion_imaging_numba_util.psf_precision_operator_sparse_from(
512575
noise_map_native=np.array(self.noise_map.native.array).astype("float64"),
513576
kernel_native=np.array(self.psf.native.array).astype("float64"),
514577
native_index_for_slim_index=np.array(
515578
self.mask.derive_indexes.native_for_slim
516579
).astype("int"),
517580
)
518581

519-
w_tilde = WTildeImaging(
520-
curvature_preload=curvature_preload,
582+
sparse_operator = inversion_imaging_numba_util.SparseLinAlgImagingNumba(
583+
psf_precision_operator_sparse=psf_precision_operator_sparse,
521584
indexes=indexes.astype("int"),
522585
lengths=lengths.astype("int"),
523586
noise_map=self.noise_map,
@@ -534,7 +597,7 @@ def apply_w_tilde(self, disable_fft_pad: bool = False):
534597
over_sample_size_pixelization=self.over_sample_size_pixelization,
535598
disable_fft_pad=disable_fft_pad,
536599
check_noise_map=False,
537-
w_tilde=w_tilde,
600+
sparse_operator=sparse_operator,
538601
)
539602

540603
def output_to_fits(

autoarray/dataset/imaging/w_tilde.py

Lines changed: 0 additions & 100 deletions
This file was deleted.

0 commit comments

Comments
 (0)