Skip to content

Commit 15c0e4b

Browse files
Jammy2211Jammy2211
authored andcommitted
direct Fourier transform speed up plus VRAM reduction
1 parent 0b07d5a commit 15c0e4b

File tree

5 files changed

+136
-365
lines changed

5 files changed

+136
-365
lines changed

autoarray/dataset/interferometer/dataset.py

Lines changed: 7 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,6 @@ def __init__(
2929
uv_wavelengths: np.ndarray,
3030
real_space_mask: Mask2D,
3131
transformer_class=TransformerNUFFT,
32-
dft_preload_transform: bool = True,
3332
w_tilde: Optional[WTildeInterferometer] = None,
3433
):
3534
"""
@@ -76,9 +75,6 @@ def __init__(
7675
transformer_class
7776
The class of the Fourier Transform which maps images from real space to Fourier space visibilities and
7877
the uv-plane.
79-
dft_preload_transform
80-
If True, precomputes and stores the cosine and sine terms for the Fourier transform.
81-
This accelerates repeated transforms but consumes additional memory (~1GB+ for large datasets).
8278
"""
8379
self.real_space_mask = real_space_mask
8480

@@ -94,11 +90,8 @@ def __init__(
9490
self.transformer = transformer_class(
9591
uv_wavelengths=uv_wavelengths,
9692
real_space_mask=real_space_mask,
97-
preload_transform=dft_preload_transform,
9893
)
9994

100-
self.dft_preload_transform = dft_preload_transform
101-
10295
use_w_tilde = True if w_tilde is not None else False
10396

10497
self.grids = GridsDataset(
@@ -121,7 +114,6 @@ def from_fits(
121114
noise_map_hdu=0,
122115
uv_wavelengths_hdu=0,
123116
transformer_class=TransformerNUFFT,
124-
dft_preload_transform: bool = True,
125117
):
126118
"""
127119
Factory for loading the interferometer data_type from .fits files, as well as computing properties like the
@@ -147,10 +139,12 @@ def from_fits(
147139
noise_map=noise_map,
148140
uv_wavelengths=uv_wavelengths,
149141
transformer_class=transformer_class,
150-
dft_preload_transform=dft_preload_transform,
151142
)
152143

153-
def apply_w_tilde(self, curvature_preload=None, batch_size: int = 128):
144+
def apply_w_tilde(self, curvature_preload=None, batch_size: int = 128,
145+
show_progress: bool = False,
146+
show_memory: bool = False,
147+
):
154148
"""
155149
The w_tilde formalism of the linear algebra equations precomputes the Fourier Transform of all the visibilities
156150
given the `uv_wavelengths` (see `inversion.inversion_util`).
@@ -185,6 +179,8 @@ def apply_w_tilde(self, curvature_preload=None, batch_size: int = 128):
185179
uv_wavelengths=self.uv_wavelengths,
186180
shape_masked_pixels_2d=self.transformer.grid.mask.shape_native_masked_pixels,
187181
grid_radians_2d=self.transformer.grid.mask.derive_grid.all_false.in_radians.native.array,
182+
show_memory=show_memory,
183+
show_progress=show_progress,
188184
)
189185

190186
dirty_image = self.transformer.image_from(
@@ -205,8 +201,7 @@ def apply_w_tilde(self, curvature_preload=None, batch_size: int = 128):
205201
data=self.data,
206202
noise_map=self.noise_map,
207203
uv_wavelengths=self.uv_wavelengths,
208-
transformer_class=lambda uv_wavelengths, real_space_mask, preload_transform: self.transformer,
209-
dft_preload_transform=self.dft_preload_transform,
204+
transformer_class=lambda uv_wavelengths, real_space_mask: self.transformer,
210205
w_tilde=w_tilde,
211206
)
212207

autoarray/inversion/inversion/interferometer/mapping.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,7 @@ def data_vector(self) -> np.ndarray:
6868
return inversion_interferometer_util.data_vector_via_transformed_mapping_matrix_from(
6969
transformed_mapping_matrix=self.operated_mapping_matrix,
7070
visibilities=self.data,
71-
noise_map=np.array(self.noise_map),
71+
noise_map=self.noise_map,
7272
)
7373

7474
@property

autoarray/operators/transformer.py

Lines changed: 84 additions & 77 deletions
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,6 @@ def __init__(
3838
self,
3939
uv_wavelengths: np.ndarray,
4040
real_space_mask: Mask2D,
41-
preload_transform: bool = True,
4241
):
4342
"""
4443
A direct Fourier transform (DFT) operator for radio interferometric imaging.
@@ -56,9 +55,6 @@ def __init__(
5655
The (u, v) coordinates in wavelengths of the measured visibilities.
5756
real_space_mask
5857
The real-space mask that defines the image grid and which pixels are valid.
59-
preload_transform
60-
If True, precomputes and stores the cosine and sine terms for the Fourier transform.
61-
This accelerates repeated transforms but consumes additional memory (~1GB+ for large datasets).
6258
6359
Attributes
6460
----------
@@ -86,26 +82,6 @@ def __init__(
8682
self.total_visibilities = uv_wavelengths.shape[0]
8783
self.total_image_pixels = self.real_space_mask.pixels_in_mask
8884

89-
self.preload_transform = preload_transform
90-
91-
if preload_transform:
92-
93-
self.preload_real_transforms = (
94-
transformer_util.preload_real_transforms_from(
95-
grid_radians=np.array(self.grid.array),
96-
uv_wavelengths=self.uv_wavelengths,
97-
)
98-
)
99-
100-
self.preload_imag_transforms = (
101-
transformer_util.preload_imag_transforms_from(
102-
grid_radians=np.array(self.grid.array),
103-
uv_wavelengths=self.uv_wavelengths,
104-
)
105-
)
106-
107-
self.real_space_pixels = self.real_space_mask.pixels_in_mask
108-
10985
# NOTE: This is the scaling factor that needs to be applied to the adjoint operator
11086
self.adjoint_scaling = (2.0 * self.grid.shape_native[0]) * (
11187
2.0 * self.grid.shape_native[1]
@@ -118,8 +94,6 @@ def visibilities_from(self, image: Array2D, xp=np) -> Visibilities:
11894
This method transforms the input image into the uv-plane (Fourier space), simulating the
11995
measurements made by an interferometer at specified uv-wavelengths.
12096
121-
If `preload_transform` is True, it uses precomputed sine and cosine terms to accelerate the computation.
122-
12397
Parameters
12498
----------
12599
image
@@ -130,22 +104,15 @@ def visibilities_from(self, image: Array2D, xp=np) -> Visibilities:
130104
-------
131105
The complex visibilities resulting from the Fourier transform of the input image.
132106
"""
133-
if self.preload_transform:
134-
visibilities = transformer_util.visibilities_via_preload_from(
135-
image_1d=image.array,
136-
preloaded_reals=self.preload_real_transforms,
137-
preloaded_imags=self.preload_imag_transforms,
138-
xp=xp,
139-
)
140-
else:
141-
visibilities = transformer_util.visibilities_from(
142-
image_1d=image.slim.array,
143-
grid_radians=self.grid.array,
144-
uv_wavelengths=self.uv_wavelengths,
145-
xp=xp,
146-
)
147107

148-
return Visibilities(visibilities=xp.array(visibilities))
108+
visibilities = transformer_util.visibilities_from(
109+
image_1d=image.slim.array,
110+
grid_radians=self.grid.array,
111+
uv_wavelengths=self.uv_wavelengths,
112+
xp=xp,
113+
)
114+
115+
return Visibilities(visibilities=visibilities)
149116

150117
def image_from(
151118
self, visibilities: Visibilities, use_adjoint_scaling: bool = False, xp=np
@@ -189,8 +156,6 @@ def transform_mapping_matrix(self, mapping_matrix: np.ndarray, xp=np) -> np.ndar
189156
(represented by a column of the mapping matrix) is computed individually. The result is a matrix
190157
mapping source pixels directly to visibilities.
191158
192-
If `preload_transform` is True, the computation is accelerated using precomputed sine and cosine terms.
193-
194159
Parameters
195160
----------
196161
mapping_matrix
@@ -201,17 +166,12 @@ def transform_mapping_matrix(self, mapping_matrix: np.ndarray, xp=np) -> np.ndar
201166
A 2D complex-valued array of shape (n_visibilities, n_source_pixels) that maps source-plane basis
202167
functions directly to the visibilities.
203168
"""
204-
if self.preload_transform:
205-
return transformer_util.transformed_mapping_matrix_via_preload_from(
206-
mapping_matrix=mapping_matrix,
207-
preloaded_reals=self.preload_real_transforms,
208-
preloaded_imags=self.preload_imag_transforms,
209-
)
210169

211170
return transformer_util.transformed_mapping_matrix_from(
212171
mapping_matrix=mapping_matrix,
213172
grid_radians=self.grid.array,
214173
uv_wavelengths=self.uv_wavelengths,
174+
xp=xp
215175
)
216176

217177

@@ -256,8 +216,6 @@ def __init__(
256216
Index map converting from slim (1D) grid to native (2D) indexing, for image reshaping.
257217
shift : np.ndarray
258218
Complex exponential phase shift applied to account for real-space pixel centering.
259-
real_space_pixels : int
260-
Total number of valid real-space pixels defined by the mask.
261219
total_visibilities : int
262220
Total number of visibilities across all uv-wavelength components.
263221
adjoint_scaling : float
@@ -298,8 +256,6 @@ def __init__(
298256
)
299257
)
300258

301-
self.real_space_pixels = self.real_space_mask.pixels_in_mask
302-
303259
# NOTE: If reshaped the shape of the operator is (2 x Nvis, Np) else it is (Nvis, Np)
304260
self.total_visibilities = int(uv_wavelengths.shape[0] * uv_wavelengths.shape[1])
305261

@@ -362,33 +318,73 @@ def initialize_plan(self, ratio: int = 2, interp_kernel: Tuple[int, int] = (6, 6
362318
Jd=interp_kernel,
363319
)
364320

365-
def visibilities_from(self, image: Array2D, xp=np) -> Visibilities:
321+
def _pynufft_forward_numpy(self, image_np: np.ndarray) -> np.ndarray:
366322
"""
367-
Computes visibilities from a real-space image using the NUFFT forward transform.
323+
NumPy-only forward NUFFT. Runs on host.
324+
"""
325+
warnings.filterwarnings("ignore")
368326

369-
Parameters
370-
----------
371-
image
372-
The input image in real space, represented as a 2D array object.
327+
# Flip vertically (PyNUFFT internal convention)
328+
image_np = image_np[::-1, :]
373329

374-
Returns
375-
-------
376-
The complex visibilities in the uv-plane computed via the NUFFT forward operation.
330+
# PyNUFFT forward
331+
vis = self.forward(image_np)
377332

378-
Notes
379-
-----
380-
- The image is flipped vertically before transformation to account for PyNUFFT’s internal data layout.
381-
- Warnings during the NUFFT computation are suppressed for cleaner output.
333+
return vis
334+
335+
def visibilities_from_jax(self, image: np.ndarray) -> np.ndarray:
336+
"""
337+
JAX-compatible wrapper around PyNUFFT forward.
338+
Can be used inside jax.jit.
382339
"""
383340

384-
warnings.filterwarnings("ignore")
341+
import jax
342+
import jax.numpy as jnp
343+
from jax import ShapeDtypeStruct
344+
345+
# You MUST tell JAX the output shape & dtype
385346

386-
return Visibilities(
387-
visibilities=self.forward(
388-
image.native.array[::-1, :]
389-
) # flip due to PyNUFFT internal flip
347+
out_shape = (self.total_visibilities // 2,) # example
348+
out_dtype = jnp.complex128
349+
350+
result_shape = ShapeDtypeStruct(
351+
shape=out_shape,
352+
dtype=out_dtype,
390353
)
391354

355+
return jax.pure_callback(
356+
lambda img: self._pynufft_forward_numpy(img),
357+
result_shape,
358+
image,
359+
)
360+
361+
def visibilities_from(self, image, xp=np):
362+
363+
# start with native image padded with zeros
364+
image_native = xp.zeros(image.mask.shape, dtype=image.dtype)
365+
366+
if xp.__name__.startswith("jax"):
367+
368+
image_native = image_native.at[image.mask.slim_to_native_tuple].set(
369+
image.array
370+
)
371+
372+
else:
373+
374+
image_native[image.mask.slim_to_native_tuple] = image.array
375+
376+
if xp is np:
377+
warnings.filterwarnings("ignore")
378+
return Visibilities(
379+
visibilities=self.forward(image_native[::-1, :])
380+
)
381+
382+
else:
383+
384+
vis = self.visibilities_from_jax(image_native)
385+
386+
return Visibilities(visibilities=vis)
387+
392388
def image_from(
393389
self, visibilities: Visibilities, use_adjoint_scaling: bool = False, xp=np
394390
) -> Array2D:
@@ -446,16 +442,27 @@ def transform_mapping_matrix(self, mapping_matrix: np.ndarray, xp=np) -> np.ndar
446442

447443
for source_pixel_1d_index in range(mapping_matrix.shape[1]):
448444

449-
image_2d = array_2d_util.array_2d_native_from(
450-
array_2d_slim=mapping_matrix[:, source_pixel_1d_index],
451-
mask_2d=self.grid.mask,
452-
xp=xp,
453-
)
445+
image_2d = xp.zeros(self.grid.shape_native, dtype=mapping_matrix.dtype)
446+
447+
if xp.__name__.startswith("jax"):
448+
449+
image_2d = image_2d.at[self.grid.mask.slim_to_native_tuple].set(
450+
mapping_matrix[:, source_pixel_1d_index]
451+
)
452+
453+
else:
454+
455+
image_2d[self.grid.mask.slim_to_native_tuple] = mapping_matrix[
456+
:, source_pixel_1d_index
457+
]
454458

455459
image = Array2D(values=image_2d, mask=self.grid.mask)
456460

457461
visibilities = self.visibilities_from(image=image, xp=xp)
458462

459-
transformed_mapping_matrix[:, source_pixel_1d_index] = visibilities
463+
if xp.__name__.startswith("jax"):
464+
transformed_mapping_matrix = transformed_mapping_matrix.at[:, source_pixel_1d_index].set(visibilities.array)
465+
else:
466+
transformed_mapping_matrix[:, source_pixel_1d_index] = visibilities.array
460467

461468
return transformed_mapping_matrix

0 commit comments

Comments
 (0)