Skip to content

Commit e72defb

Browse files
authored
Merge pull request #201 from Jammy2211/feature/interferometer
Feature/interferometer
2 parents 7810f8d + 0a80432 commit e72defb

File tree

18 files changed

+688
-2237
lines changed

18 files changed

+688
-2237
lines changed

autoarray/dataset/interferometer/dataset.py

Lines changed: 50 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -1,23 +1,22 @@
11
import logging
22
import numpy as np
3-
from pathlib import Path
43
from typing import Optional
54

6-
from autoconf import cached_property
7-
85
from autoconf.fitsable import ndarray_via_fits_from, output_to_fits
96

107
from autoarray.dataset.abstract.dataset import AbstractDataset
118
from autoarray.dataset.interferometer.w_tilde import WTildeInterferometer
129
from autoarray.dataset.grids import GridsDataset
10+
from autoarray.operators.transformer import TransformerDFT
1311
from autoarray.operators.transformer import TransformerNUFFT
1412
from autoarray.mask.mask_2d import Mask2D
1513
from autoarray.structures.visibilities import Visibilities
1614
from autoarray.structures.visibilities import VisibilitiesNoiseMap
1715

18-
from autoarray.inversion.inversion.interferometer import inversion_interferometer_util
19-
2016
from autoarray import exc
17+
from autoarray.inversion.inversion.interferometer import (
18+
inversion_interferometer_util,
19+
)
2120

2221
logger = logging.getLogger(__name__)
2322

@@ -30,8 +29,8 @@ def __init__(
3029
uv_wavelengths: np.ndarray,
3130
real_space_mask: Mask2D,
3231
transformer_class=TransformerNUFFT,
33-
dft_preload_transform: bool = True,
3432
w_tilde: Optional[WTildeInterferometer] = None,
33+
raise_error_dft_visibilities_limit: bool = True,
3534
):
3635
"""
3736
An interferometer dataset, containing the visibilities data, noise-map, real-space msk, Fourier transformer and
@@ -77,9 +76,6 @@ def __init__(
7776
transformer_class
7877
The class of the Fourier Transform which maps images from real space to Fourier space visibilities and
7978
the uv-plane.
80-
dft_preload_transform
81-
If True, precomputes and stores the cosine and sine terms for the Fourier transform.
82-
This accelerates repeated transforms but consumes additional memory (~1GB+ for large datasets).
8379
"""
8480
self.real_space_mask = real_space_mask
8581

@@ -95,11 +91,8 @@ def __init__(
9591
self.transformer = transformer_class(
9692
uv_wavelengths=uv_wavelengths,
9793
real_space_mask=real_space_mask,
98-
preload_transform=dft_preload_transform,
9994
)
10095

101-
self.dft_preload_transform = dft_preload_transform
102-
10396
use_w_tilde = True if w_tilde is not None else False
10497

10598
self.grids = GridsDataset(
@@ -111,6 +104,22 @@ def __init__(
111104

112105
self.w_tilde = w_tilde
113106

107+
if raise_error_dft_visibilities_limit:
108+
if (
109+
self.uv_wavelengths.shape[0] > 10000
110+
and transformer_class == TransformerDFT
111+
):
112+
raise exc.DatasetException(
113+
"""
114+
Interferometer datasets with more than 10,000 visibilities should use the TransformerNUFFT class for
115+
efficient Fourier transforms between real and uv-space. The DFT (Discrete Fourier Transform) is too slow for
116+
large datasets.
117+
118+
If you are certain you want to use the TransformerDFT class, you can disable this error by passing
119+
the input `raise_error_dft_visibilities_limit=False` when loading the Interferometer dataset.
120+
"""
121+
)
122+
114123
@classmethod
115124
def from_fits(
116125
cls,
@@ -122,7 +131,6 @@ def from_fits(
122131
noise_map_hdu=0,
123132
uv_wavelengths_hdu=0,
124133
transformer_class=TransformerNUFFT,
125-
dft_preload_transform: bool = True,
126134
):
127135
"""
128136
Factory for loading the interferometer data_type from .fits files, as well as computing properties like the
@@ -148,10 +156,15 @@ def from_fits(
148156
noise_map=noise_map,
149157
uv_wavelengths=uv_wavelengths,
150158
transformer_class=transformer_class,
151-
dft_preload_transform=dft_preload_transform,
152159
)
153160

154-
def apply_w_tilde(self):
161+
def apply_w_tilde(
162+
self,
163+
curvature_preload=None,
164+
batch_size: int = 128,
165+
show_progress: bool = False,
166+
show_memory: bool = False,
167+
):
155168
"""
156169
The w_tilde formalism of the linear algebra equations precomputes the Fourier Transform of all the visibilities
157170
given the `uv_wavelengths` (see `inversion.inversion_util`).
@@ -162,44 +175,33 @@ def apply_w_tilde(self):
162175
This uses lazy allocation such that the calculation is only performed when the wtilde matrices are used,
163176
ensuring efficient set up of the `Interferometer` class.
164177
178+
Parameters
179+
----------
180+
curvature_preload
181+
An already computed curvature preload matrix for this dataset (e.g. loaded from hard-disk), to prevent
182+
long recalculations of this matrix for large datasets.
183+
batch_size
184+
The size of batches used to compute the w-tilde curvature matrix via FFT-based convolution,
185+
which can be reduced to produce lower memory usage at the cost of speed.
186+
165187
Returns
166188
-------
167189
WTildeInterferometer
168190
Precomputed values used for the w tilde formalism of linear algebra calculations.
169191
"""
170192

171-
logger.info("INTERFEROMETER - Computing W-Tilde... May take a moment.")
172-
173-
try:
174-
import numba
175-
except ModuleNotFoundError:
176-
raise exc.InversionException(
177-
"Inversion w-tilde functionality (pixelized reconstructions) is "
178-
"disabled if numba is not installed.\n\n"
179-
"This is because the run-times without numba are too slow.\n\n"
180-
"Please install numba, which is described at the following web page:\n\n"
181-
"https://pyautolens.readthedocs.io/en/latest/installation/overview.html"
182-
)
193+
if curvature_preload is None:
183194

184-
curvature_preload = (
185-
inversion_interferometer_util.w_tilde_curvature_preload_interferometer_from(
186-
noise_map_real=np.array(self.noise_map.real),
187-
uv_wavelengths=np.array(self.uv_wavelengths),
188-
shape_masked_pixels_2d=np.array(
189-
self.transformer.grid.mask.shape_native_masked_pixels
190-
),
191-
grid_radians_2d=np.array(
192-
self.transformer.grid.mask.derive_grid.all_false.in_radians.native
193-
),
194-
)
195-
)
195+
logger.info("INTERFEROMETER - Computing W-Tilde... May take a moment.")
196196

197-
w_matrix = inversion_interferometer_util.w_tilde_via_preload_from(
198-
w_tilde_preload=curvature_preload,
199-
native_index_for_slim_index=np.array(
200-
self.real_space_mask.derive_indexes.native_for_slim
201-
).astype("int"),
202-
)
197+
curvature_preload = inversion_interferometer_util.w_tilde_curvature_preload_interferometer_from(
198+
noise_map_real=self.noise_map.array.real,
199+
uv_wavelengths=self.uv_wavelengths,
200+
shape_masked_pixels_2d=self.transformer.grid.mask.shape_native_masked_pixels,
201+
grid_radians_2d=self.transformer.grid.mask.derive_grid.all_false.in_radians.native.array,
202+
show_memory=show_memory,
203+
show_progress=show_progress,
204+
)
203205

204206
dirty_image = self.transformer.image_from(
205207
visibilities=self.data.real * self.noise_map.real**-2.0
@@ -208,19 +210,18 @@ def apply_w_tilde(self):
208210
)
209211

210212
w_tilde = WTildeInterferometer(
211-
w_matrix=w_matrix,
212213
curvature_preload=curvature_preload,
213-
dirty_image=np.array(dirty_image.array),
214+
dirty_image=dirty_image.array,
214215
real_space_mask=self.real_space_mask,
216+
batch_size=batch_size,
215217
)
216218

217219
return Interferometer(
218220
real_space_mask=self.real_space_mask,
219221
data=self.data,
220222
noise_map=self.noise_map,
221223
uv_wavelengths=self.uv_wavelengths,
222-
transformer_class=lambda uv_wavelengths, real_space_mask, preload_transform: self.transformer,
223-
dft_preload_transform=self.dft_preload_transform,
224+
transformer_class=lambda uv_wavelengths, real_space_mask: self.transformer,
224225
w_tilde=w_tilde,
225226
)
226227

autoarray/dataset/interferometer/w_tilde.py

Lines changed: 81 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,10 +7,10 @@
77
class WTildeInterferometer(AbstractWTilde):
88
def __init__(
99
self,
10-
w_matrix: np.ndarray,
1110
curvature_preload: np.ndarray,
1211
dirty_image: np.ndarray,
1312
real_space_mask: Mask2D,
13+
batch_size: int = 128,
1414
):
1515
"""
1616
Packages together all derived data quantities necessary to fit `Interferometer` data using an ` Inversion` via
@@ -34,6 +34,9 @@ def __init__(
3434
real_space_mask
3535
The 2D mask in real-space defining the area where the interferometer data's visibilities are observing
3636
a signal.
37+
batch_size
38+
The size of batches used to compute the w-tilde curvature matrix via FFT-based convolution,
39+
which can be reduced to produce lower memory usage at the cost of speed.
3740
"""
3841
super().__init__(
3942
curvature_preload=curvature_preload,
@@ -42,4 +45,80 @@ def __init__(
4245
self.dirty_image = dirty_image
4346
self.real_space_mask = real_space_mask
4447

45-
self.w_matrix = w_matrix
48+
from autoarray.inversion.inversion.interferometer import (
49+
inversion_interferometer_util,
50+
)
51+
52+
self.fft_state = inversion_interferometer_util.w_tilde_fft_state_from(
53+
curvature_preload=self.curvature_preload, batch_size=batch_size
54+
)
55+
56+
@property
57+
def mask_rectangular_w_tilde(self) -> np.ndarray:
58+
"""
59+
Returns a rectangular boolean mask that tightly bounds the unmasked region
60+
of the interferometer mask.
61+
62+
This rectangular mask is used for computing the W-tilde curvature matrix
63+
via FFT-based convolution, which requires a full rectangular grid.
64+
65+
Pixels outside the bounding box of the original mask are set to True
66+
(masked), and pixels inside are False (unmasked).
67+
68+
Returns
69+
-------
70+
np.ndarray
71+
Boolean mask of shape (Ny, Nx), where False denotes unmasked pixels.
72+
"""
73+
mask = self.real_space_mask
74+
75+
ys, xs = np.where(~mask)
76+
77+
y_min, y_max = ys.min(), ys.max()
78+
x_min, x_max = xs.min(), xs.max()
79+
80+
rect_mask = np.ones(mask.shape, dtype=bool)
81+
rect_mask[y_min : y_max + 1, x_min : x_max + 1] = False
82+
83+
return rect_mask
84+
85+
@property
86+
def rect_index_for_mask_index(self) -> np.ndarray:
87+
"""
88+
Mapping from masked-grid pixel indices to rectangular-grid pixel indices.
89+
90+
This array enables extraction of a curvature matrix computed on a full
91+
rectangular grid back to the original masked grid.
92+
93+
If:
94+
- C_rect is the curvature matrix computed on the rectangular grid
95+
- idx = rect_index_for_mask_index
96+
97+
then the masked curvature matrix is:
98+
C_mask = C_rect[idx[:, None], idx[None, :]]
99+
100+
Returns
101+
-------
102+
np.ndarray
103+
Array of shape (N_masked_pixels,), where each entry gives the
104+
corresponding index in the rectangular grid (row-major order).
105+
"""
106+
mask = self.real_space_mask
107+
rect_mask = self.mask_rectangular_w_tilde
108+
109+
# Bounding box of the rectangular region
110+
ys, xs = np.where(~rect_mask)
111+
y_min, y_max = ys.min(), ys.max()
112+
x_min, x_max = xs.min(), xs.max()
113+
114+
rect_width = x_max - x_min + 1
115+
116+
# Coordinates of unmasked pixels in the original mask (slim order)
117+
mask_ys, mask_xs = np.where(~mask)
118+
119+
# Convert (y, x) → rectangular flat index
120+
rect_indices = ((mask_ys - y_min) * rect_width + (mask_xs - x_min)).astype(
121+
np.int32
122+
)
123+
124+
return rect_indices

autoarray/fit/fit_interferometer.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -126,6 +126,38 @@ def noise_normalization(self) -> float:
126126
noise_map=self.noise_map.array,
127127
)
128128

129+
@property
130+
def log_evidence(self) -> float:
131+
"""
132+
Returns the log evidence of the inversion's fit to a dataset, where the log evidence includes a number of terms
133+
which quantify the complexity of an inversion's reconstruction (see the `Inversion` module):
134+
135+
Log Evidence = -0.5*[Chi_Squared_Term + Regularization_Term + Log(Covariance_Regularization_Term) -
136+
Log(Regularization_Matrix_Term) + Noise_Term]
137+
138+
Parameters
139+
----------
140+
chi_squared
141+
The chi-squared term of the inversion's fit to the data.
142+
regularization_term
143+
The regularization term of the inversion, which is the sum of the difference between reconstructed \
144+
flux of every pixel multiplied by the regularization coefficient.
145+
log_curvature_regularization_term
146+
The log of the determinant of the sum of the curvature and regularization matrices.
147+
log_regularization_term
148+
The log of the determinant o the regularization matrix.
149+
noise_normalization
150+
The normalization noise_map-term for the data's noise-map.
151+
"""
152+
if self.inversion is not None:
153+
return fit_util.log_evidence_from(
154+
chi_squared=self.inversion.fast_chi_squared,
155+
regularization_term=self.inversion.regularization_term,
156+
log_curvature_regularization_term=self.inversion.log_det_curvature_reg_matrix_term,
157+
log_regularization_term=self.inversion.log_det_regularization_matrix_term,
158+
noise_normalization=self.noise_normalization,
159+
)
160+
129161
@property
130162
def dirty_image(self) -> Array2D:
131163
return self.transformer.image_from(visibilities=self.data)

autoarray/inversion/inversion/interferometer/abstract.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -121,3 +121,31 @@ def mapped_reconstructed_image_dict(
121121
mapped_reconstructed_image_dict[linear_obj] = mapped_reconstructed_image
122122

123123
return mapped_reconstructed_image_dict
124+
125+
@property
126+
def fast_chi_squared(self):
127+
128+
xp = self._xp
129+
130+
chi_squared_term_1 = xp.linalg.multi_dot(
131+
[
132+
self.reconstruction.T, # (M,)
133+
self.curvature_matrix, # (M, M)
134+
self.reconstruction, # (M,)
135+
]
136+
)
137+
138+
chi_squared_term_2 = -2.0 * xp.linalg.multi_dot(
139+
[
140+
self.reconstruction.T, # (M,)
141+
self.data_vector, # (M,)
142+
]
143+
)
144+
145+
chi_squared_term_3 = xp.sum(
146+
self.dataset.data.array.real**2.0 / self.dataset.noise_map.array.real**2.0
147+
) + xp.sum(
148+
self.dataset.data.array.imag**2.0 / self.dataset.noise_map.array.imag**2.0
149+
)
150+
151+
return chi_squared_term_1 + chi_squared_term_2 + chi_squared_term_3

0 commit comments

Comments
 (0)