Skip to content

Commit 0b07d5a

Browse files
Jammy2211Jammy2211
authored andcommitted
black
1 parent fc6af48 commit 0b07d5a

File tree

3 files changed

+25
-11
lines changed

3 files changed

+25
-11
lines changed

autoarray/dataset/interferometer/dataset.py

Lines changed: 20 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -150,7 +150,7 @@ def from_fits(
150150
dft_preload_transform=dft_preload_transform,
151151
)
152152

153-
def apply_w_tilde(self):
153+
def apply_w_tilde(self, curvature_preload=None, batch_size: int = 128):
154154
"""
155155
The w_tilde formalism of the linear algebra equations precomputes the Fourier Transform of all the visibilities
156156
given the `uv_wavelengths` (see `inversion.inversion_util`).
@@ -161,20 +161,31 @@ def apply_w_tilde(self):
161161
This uses lazy allocation such that the calculation is only performed when the wtilde matrices are used,
162162
ensuring efficient set up of the `Interferometer` class.
163163
164+
Parameters
165+
----------
166+
curvature_preload
167+
An already computed curvature preload matrix for this dataset (e.g. loaded from hard-disk), to prevent
168+
long recalculations of this matrix for large datasets.
169+
batch_size
170+
The size of batches used to compute the w-tilde curvature matrix via FFT-based convolution,
171+
which can be reduced to produce lower memory usage at the cost of speed.
172+
164173
Returns
165174
-------
166175
WTildeInterferometer
167176
Precomputed values used for the w tilde formalism of linear algebra calculations.
168177
"""
169178

170-
logger.info("INTERFEROMETER - Computing W-Tilde... May take a moment.")
179+
if curvature_preload is None:
171180

172-
curvature_preload = inversion_interferometer_util.w_tilde_curvature_preload_interferometer_from(
173-
noise_map_real=self.noise_map.array.real,
174-
uv_wavelengths=self.uv_wavelengths,
175-
shape_masked_pixels_2d=self.transformer.grid.mask.shape_native_masked_pixels,
176-
grid_radians_2d=self.transformer.grid.mask.derive_grid.all_false.in_radians.native.array,
177-
)
181+
logger.info("INTERFEROMETER - Computing W-Tilde... May take a moment.")
182+
183+
curvature_preload = inversion_interferometer_util.w_tilde_curvature_preload_interferometer_from(
184+
noise_map_real=self.noise_map.array.real,
185+
uv_wavelengths=self.uv_wavelengths,
186+
shape_masked_pixels_2d=self.transformer.grid.mask.shape_native_masked_pixels,
187+
grid_radians_2d=self.transformer.grid.mask.derive_grid.all_false.in_radians.native.array,
188+
)
178189

179190
dirty_image = self.transformer.image_from(
180191
visibilities=self.data.real * self.noise_map.real**-2.0
@@ -186,6 +197,7 @@ def apply_w_tilde(self):
186197
curvature_preload=curvature_preload,
187198
dirty_image=dirty_image.array,
188199
real_space_mask=self.real_space_mask,
200+
batch_size=batch_size,
189201
)
190202

191203
return Interferometer(

autoarray/dataset/interferometer/w_tilde.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ def __init__(
1010
curvature_preload: np.ndarray,
1111
dirty_image: np.ndarray,
1212
real_space_mask: Mask2D,
13+
batch_size: int = 128,
1314
):
1415
"""
1516
Packages together all derived data quantities necessary to fit `Interferometer` data using an ` Inversion` via
@@ -33,6 +34,9 @@ def __init__(
3334
real_space_mask
3435
The 2D mask in real-space defining the area where the interferometer data's visibilities are observing
3536
a signal.
37+
batch_size
38+
The size of batches used to compute the w-tilde curvature matrix via FFT-based convolution,
39+
which can be reduced to produce lower memory usage at the cost of speed.
3640
"""
3741
super().__init__(
3842
curvature_preload=curvature_preload,
@@ -46,7 +50,7 @@ def __init__(
4650
)
4751

4852
self.fft_state = inversion_interferometer_util.w_tilde_fft_state_from(
49-
curvature_preload=self.curvature_preload, batch_size=450
53+
curvature_preload=self.curvature_preload, batch_size=batch_size
5054
)
5155

5256
@property

autoarray/inversion/inversion/interferometer/mapping.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -107,8 +107,6 @@ def curvature_matrix(self) -> np.ndarray:
107107
xp=self._xp,
108108
)
109109

110-
print(curvature_matrix)
111-
112110
return curvature_matrix
113111

114112
@property

0 commit comments

Comments
 (0)