11import logging
22import numpy as np
3- from pathlib import Path
43from typing import Optional
54
6- from autoconf import cached_property
7-
85from autoconf .fitsable import ndarray_via_fits_from , output_to_fits
96
107from autoarray .dataset .abstract .dataset import AbstractDataset
118from autoarray .dataset .interferometer .w_tilde import WTildeInterferometer
129from autoarray .dataset .grids import GridsDataset
10+ from autoarray .operators .transformer import TransformerDFT
1311from autoarray .operators .transformer import TransformerNUFFT
1412from autoarray .mask .mask_2d import Mask2D
1513from autoarray .structures .visibilities import Visibilities
1614from autoarray .structures .visibilities import VisibilitiesNoiseMap
1715
18- from autoarray .inversion .inversion .interferometer import inversion_interferometer_util
19-
2016from autoarray import exc
17+ from autoarray .inversion .inversion .interferometer import (
18+ inversion_interferometer_util ,
19+ )
2120
2221logger = logging .getLogger (__name__ )
2322
@@ -30,8 +29,8 @@ def __init__(
3029 uv_wavelengths : np .ndarray ,
3130 real_space_mask : Mask2D ,
3231 transformer_class = TransformerNUFFT ,
33- dft_preload_transform : bool = True ,
3432 w_tilde : Optional [WTildeInterferometer ] = None ,
33+ raise_error_dft_visibilities_limit : bool = True ,
3534 ):
3635 """
3736 An interferometer dataset, containing the visibilities data, noise-map, real-space msk, Fourier transformer and
@@ -77,9 +76,6 @@ def __init__(
7776 transformer_class
7877 The class of the Fourier Transform which maps images from real space to Fourier space visibilities and
7978 the uv-plane.
80- dft_preload_transform
81- If True, precomputes and stores the cosine and sine terms for the Fourier transform.
82- This accelerates repeated transforms but consumes additional memory (~1GB+ for large datasets).
8379 """
8480 self .real_space_mask = real_space_mask
8581
@@ -95,11 +91,8 @@ def __init__(
9591 self .transformer = transformer_class (
9692 uv_wavelengths = uv_wavelengths ,
9793 real_space_mask = real_space_mask ,
98- preload_transform = dft_preload_transform ,
9994 )
10095
101- self .dft_preload_transform = dft_preload_transform
102-
10396 use_w_tilde = True if w_tilde is not None else False
10497
10598 self .grids = GridsDataset (
@@ -111,6 +104,22 @@ def __init__(
111104
112105 self .w_tilde = w_tilde
113106
107+ if raise_error_dft_visibilities_limit :
108+ if (
109+ self .uv_wavelengths .shape [0 ] > 10000
110+ and transformer_class == TransformerDFT
111+ ):
112+ raise exc .DatasetException (
113+ """
114+ Interferometer datasets with more than 10,000 visibilities should use the TransformerNUFFT class for
115+ efficient Fourier transforms between real and uv-space. The DFT (Discrete Fourier Transform) is too slow for
116+ large datasets.
117+
118+ If you are certain you want to use the TransformerDFT class, you can disable this error by passing
119+ the input `raise_error_dft_visibilities_limit=False` when loading the Interferometer dataset.
120+ """
121+ )
122+
114123 @classmethod
115124 def from_fits (
116125 cls ,
@@ -122,7 +131,6 @@ def from_fits(
122131 noise_map_hdu = 0 ,
123132 uv_wavelengths_hdu = 0 ,
124133 transformer_class = TransformerNUFFT ,
125- dft_preload_transform : bool = True ,
126134 ):
127135 """
128136 Factory for loading the interferometer data_type from .fits files, as well as computing properties like the
@@ -148,10 +156,15 @@ def from_fits(
148156 noise_map = noise_map ,
149157 uv_wavelengths = uv_wavelengths ,
150158 transformer_class = transformer_class ,
151- dft_preload_transform = dft_preload_transform ,
152159 )
153160
154- def apply_w_tilde (self ):
161+ def apply_w_tilde (
162+ self ,
163+ curvature_preload = None ,
164+ batch_size : int = 128 ,
165+ show_progress : bool = False ,
166+ show_memory : bool = False ,
167+ ):
155168 """
156169 The w_tilde formalism of the linear algebra equations precomputes the Fourier Transform of all the visibilities
157170 given the `uv_wavelengths` (see `inversion.inversion_util`).
@@ -162,44 +175,33 @@ def apply_w_tilde(self):
162175 This uses lazy allocation such that the calculation is only performed when the wtilde matrices are used,
163176 ensuring efficient set up of the `Interferometer` class.
164177
178+ Parameters
179+ ----------
180+ curvature_preload
181+ An already computed curvature preload matrix for this dataset (e.g. loaded from hard-disk), to prevent
182+ long recalculations of this matrix for large datasets.
183+ batch_size
184+ The size of batches used to compute the w-tilde curvature matrix via FFT-based convolution,
185+ which can be reduced to produce lower memory usage at the cost of speed.
186+
165187 Returns
166188 -------
167189 WTildeInterferometer
168190 Precomputed values used for the w tilde formalism of linear algebra calculations.
169191 """
170192
171- logger .info ("INTERFEROMETER - Computing W-Tilde... May take a moment." )
172-
173- try :
174- import numba
175- except ModuleNotFoundError :
176- raise exc .InversionException (
177- "Inversion w-tilde functionality (pixelized reconstructions) is "
178- "disabled if numba is not installed.\n \n "
179- "This is because the run-times without numba are too slow.\n \n "
180- "Please install numba, which is described at the following web page:\n \n "
181- "https://pyautolens.readthedocs.io/en/latest/installation/overview.html"
182- )
193+ if curvature_preload is None :
183194
184- curvature_preload = (
185- inversion_interferometer_util .w_tilde_curvature_preload_interferometer_from (
186- noise_map_real = np .array (self .noise_map .real ),
187- uv_wavelengths = np .array (self .uv_wavelengths ),
188- shape_masked_pixels_2d = np .array (
189- self .transformer .grid .mask .shape_native_masked_pixels
190- ),
191- grid_radians_2d = np .array (
192- self .transformer .grid .mask .derive_grid .all_false .in_radians .native
193- ),
194- )
195- )
195+ logger .info ("INTERFEROMETER - Computing W-Tilde... May take a moment." )
196196
197- w_matrix = inversion_interferometer_util .w_tilde_via_preload_from (
198- w_tilde_preload = curvature_preload ,
199- native_index_for_slim_index = np .array (
200- self .real_space_mask .derive_indexes .native_for_slim
201- ).astype ("int" ),
202- )
197+ curvature_preload = inversion_interferometer_util .w_tilde_curvature_preload_interferometer_from (
198+ noise_map_real = self .noise_map .array .real ,
199+ uv_wavelengths = self .uv_wavelengths ,
200+ shape_masked_pixels_2d = self .transformer .grid .mask .shape_native_masked_pixels ,
201+ grid_radians_2d = self .transformer .grid .mask .derive_grid .all_false .in_radians .native .array ,
202+ show_memory = show_memory ,
203+ show_progress = show_progress ,
204+ )
203205
204206 dirty_image = self .transformer .image_from (
205207 visibilities = self .data .real * self .noise_map .real ** - 2.0
@@ -208,19 +210,18 @@ def apply_w_tilde(self):
208210 )
209211
210212 w_tilde = WTildeInterferometer (
211- w_matrix = w_matrix ,
212213 curvature_preload = curvature_preload ,
213- dirty_image = np . array ( dirty_image .array ) ,
214+ dirty_image = dirty_image .array ,
214215 real_space_mask = self .real_space_mask ,
216+ batch_size = batch_size ,
215217 )
216218
217219 return Interferometer (
218220 real_space_mask = self .real_space_mask ,
219221 data = self .data ,
220222 noise_map = self .noise_map ,
221223 uv_wavelengths = self .uv_wavelengths ,
222- transformer_class = lambda uv_wavelengths , real_space_mask , preload_transform : self .transformer ,
223- dft_preload_transform = self .dft_preload_transform ,
224+ transformer_class = lambda uv_wavelengths , real_space_mask : self .transformer ,
224225 w_tilde = w_tilde ,
225226 )
226227
0 commit comments