@@ -150,7 +150,7 @@ def from_fits(
150150 dft_preload_transform = dft_preload_transform ,
151151 )
152152
153- def apply_w_tilde (self ):
153+ def apply_w_tilde (self , curvature_preload = None , batch_size : int = 128 ):
154154 """
155155 The w_tilde formalism of the linear algebra equations precomputes the Fourier Transform of all the visibilities
156156 given the `uv_wavelengths` (see `inversion.inversion_util`).
@@ -161,20 +161,31 @@ def apply_w_tilde(self):
161161 This uses lazy allocation such that the calculation is only performed when the wtilde matrices are used,
162162 ensuring efficient set up of the `Interferometer` class.
163163
164+ Parameters
165+ ----------
166+ curvature_preload
167+ An already computed curvature preload matrix for this dataset (e.g. loaded from hard-disk), to prevent
168+ long recalculations of this matrix for large datasets.
169+ batch_size
170+ The size of batches used to compute the w-tilde curvature matrix via FFT-based convolution,
171+ which can be reduced to produce lower memory usage at the cost of speed.
172+
164173 Returns
165174 -------
166175 WTildeInterferometer
167176 Precomputed values used for the w tilde formalism of linear algebra calculations.
168177 """
169178
170- logger . info ( "INTERFEROMETER - Computing W-Tilde... May take a moment." )
179+ if curvature_preload is None :
171180
172- curvature_preload = inversion_interferometer_util .w_tilde_curvature_preload_interferometer_from (
173- noise_map_real = self .noise_map .array .real ,
174- uv_wavelengths = self .uv_wavelengths ,
175- shape_masked_pixels_2d = self .transformer .grid .mask .shape_native_masked_pixels ,
176- grid_radians_2d = self .transformer .grid .mask .derive_grid .all_false .in_radians .native .array ,
177- )
181+ logger .info ("INTERFEROMETER - Computing W-Tilde... May take a moment." )
182+
183+ curvature_preload = inversion_interferometer_util .w_tilde_curvature_preload_interferometer_from (
184+ noise_map_real = self .noise_map .array .real ,
185+ uv_wavelengths = self .uv_wavelengths ,
186+ shape_masked_pixels_2d = self .transformer .grid .mask .shape_native_masked_pixels ,
187+ grid_radians_2d = self .transformer .grid .mask .derive_grid .all_false .in_radians .native .array ,
188+ )
178189
179190 dirty_image = self .transformer .image_from (
180191 visibilities = self .data .real * self .noise_map .real ** - 2.0
@@ -186,6 +197,7 @@ def apply_w_tilde(self):
186197 curvature_preload = curvature_preload ,
187198 dirty_image = dirty_image .array ,
188199 real_space_mask = self .real_space_mask ,
200+ batch_size = batch_size ,
189201 )
190202
191203 return Interferometer (
0 commit comments