@@ -38,7 +38,6 @@ def __init__(
3838 self ,
3939 uv_wavelengths : np .ndarray ,
4040 real_space_mask : Mask2D ,
41- preload_transform : bool = True ,
4241 ):
4342 """
4443 A direct Fourier transform (DFT) operator for radio interferometric imaging.
@@ -56,9 +55,6 @@ def __init__(
5655 The (u, v) coordinates in wavelengths of the measured visibilities.
5756 real_space_mask
5857 The real-space mask that defines the image grid and which pixels are valid.
59- preload_transform
60- If True, precomputes and stores the cosine and sine terms for the Fourier transform.
61- This accelerates repeated transforms but consumes additional memory (~1GB+ for large datasets).
6258
6359 Attributes
6460 ----------
@@ -86,26 +82,6 @@ def __init__(
8682 self .total_visibilities = uv_wavelengths .shape [0 ]
8783 self .total_image_pixels = self .real_space_mask .pixels_in_mask
8884
89- self .preload_transform = preload_transform
90-
91- if preload_transform :
92-
93- self .preload_real_transforms = (
94- transformer_util .preload_real_transforms_from (
95- grid_radians = np .array (self .grid .array ),
96- uv_wavelengths = self .uv_wavelengths ,
97- )
98- )
99-
100- self .preload_imag_transforms = (
101- transformer_util .preload_imag_transforms_from (
102- grid_radians = np .array (self .grid .array ),
103- uv_wavelengths = self .uv_wavelengths ,
104- )
105- )
106-
107- self .real_space_pixels = self .real_space_mask .pixels_in_mask
108-
10985 # NOTE: This is the scaling factor that needs to be applied to the adjoint operator
11086 self .adjoint_scaling = (2.0 * self .grid .shape_native [0 ]) * (
11187 2.0 * self .grid .shape_native [1 ]
@@ -118,8 +94,6 @@ def visibilities_from(self, image: Array2D, xp=np) -> Visibilities:
11894 This method transforms the input image into the uv-plane (Fourier space), simulating the
11995 measurements made by an interferometer at specified uv-wavelengths.
12096
121- If `preload_transform` is True, it uses precomputed sine and cosine terms to accelerate the computation.
122-
12397 Parameters
12498 ----------
12599 image
@@ -130,22 +104,15 @@ def visibilities_from(self, image: Array2D, xp=np) -> Visibilities:
130104 -------
131105 The complex visibilities resulting from the Fourier transform of the input image.
132106 """
133- if self .preload_transform :
134- visibilities = transformer_util .visibilities_via_preload_from (
135- image_1d = image .array ,
136- preloaded_reals = self .preload_real_transforms ,
137- preloaded_imags = self .preload_imag_transforms ,
138- xp = xp ,
139- )
140- else :
141- visibilities = transformer_util .visibilities_from (
142- image_1d = image .slim .array ,
143- grid_radians = self .grid .array ,
144- uv_wavelengths = self .uv_wavelengths ,
145- xp = xp ,
146- )
147107
148- return Visibilities (visibilities = xp .array (visibilities ))
108+ visibilities = transformer_util .visibilities_from (
109+ image_1d = image .slim .array ,
110+ grid_radians = self .grid .array ,
111+ uv_wavelengths = self .uv_wavelengths ,
112+ xp = xp ,
113+ )
114+
115+ return Visibilities (visibilities = visibilities )
149116
150117 def image_from (
151118 self , visibilities : Visibilities , use_adjoint_scaling : bool = False , xp = np
@@ -189,8 +156,6 @@ def transform_mapping_matrix(self, mapping_matrix: np.ndarray, xp=np) -> np.ndar
189156 (represented by a column of the mapping matrix) is computed individually. The result is a matrix
190157 mapping source pixels directly to visibilities.
191158
192- If `preload_transform` is True, the computation is accelerated using precomputed sine and cosine terms.
193-
194159 Parameters
195160 ----------
196161 mapping_matrix
@@ -201,17 +166,12 @@ def transform_mapping_matrix(self, mapping_matrix: np.ndarray, xp=np) -> np.ndar
201166 A 2D complex-valued array of shape (n_visibilities, n_source_pixels) that maps source-plane basis
202167 functions directly to the visibilities.
203168 """
204- if self .preload_transform :
205- return transformer_util .transformed_mapping_matrix_via_preload_from (
206- mapping_matrix = mapping_matrix ,
207- preloaded_reals = self .preload_real_transforms ,
208- preloaded_imags = self .preload_imag_transforms ,
209- )
210169
211170 return transformer_util .transformed_mapping_matrix_from (
212171 mapping_matrix = mapping_matrix ,
213172 grid_radians = self .grid .array ,
214173 uv_wavelengths = self .uv_wavelengths ,
174+ xp = xp
215175 )
216176
217177
@@ -256,8 +216,6 @@ def __init__(
256216 Index map converting from slim (1D) grid to native (2D) indexing, for image reshaping.
257217 shift : np.ndarray
258218 Complex exponential phase shift applied to account for real-space pixel centering.
259- real_space_pixels : int
260- Total number of valid real-space pixels defined by the mask.
261219 total_visibilities : int
262220 Total number of visibilities across all uv-wavelength components.
263221 adjoint_scaling : float
@@ -298,8 +256,6 @@ def __init__(
298256 )
299257 )
300258
301- self .real_space_pixels = self .real_space_mask .pixels_in_mask
302-
303259 # NOTE: If reshaped the shape of the operator is (2 x Nvis, Np) else it is (Nvis, Np)
304260 self .total_visibilities = int (uv_wavelengths .shape [0 ] * uv_wavelengths .shape [1 ])
305261
@@ -362,33 +318,73 @@ def initialize_plan(self, ratio: int = 2, interp_kernel: Tuple[int, int] = (6, 6
362318 Jd = interp_kernel ,
363319 )
364320
365- def visibilities_from (self , image : Array2D , xp = np ) -> Visibilities :
321+ def _pynufft_forward_numpy (self , image_np : np . ndarray ) -> np . ndarray :
366322 """
367- Computes visibilities from a real-space image using the NUFFT forward transform.
323+ NumPy-only forward NUFFT. Runs on host.
324+ """
325+ warnings .filterwarnings ("ignore" )
368326
369- Parameters
370- ----------
371- image
372- The input image in real space, represented as a 2D array object.
327+ # Flip vertically (PyNUFFT internal convention)
328+ image_np = image_np [::- 1 , :]
373329
374- Returns
375- -------
376- The complex visibilities in the uv-plane computed via the NUFFT forward operation.
330+ # PyNUFFT forward
331+ vis = self .forward (image_np )
377332
378- Notes
379- -----
380- - The image is flipped vertically before transformation to account for PyNUFFT’s internal data layout.
381- - Warnings during the NUFFT computation are suppressed for cleaner output.
333+ return vis
334+
335+ def visibilities_from_jax (self , image : np .ndarray ) -> np .ndarray :
336+ """
337+ JAX-compatible wrapper around PyNUFFT forward.
338+ Can be used inside jax.jit.
382339 """
383340
384- warnings .filterwarnings ("ignore" )
341+ import jax
342+ import jax .numpy as jnp
343+ from jax import ShapeDtypeStruct
344+
345+ # You MUST tell JAX the output shape & dtype
385346
386- return Visibilities (
387- visibilities = self .forward (
388- image .native .array [::- 1 , :]
389- ) # flip due to PyNUFFT internal flip
347+ out_shape = (self .total_visibilities // 2 ,) # example
348+ out_dtype = jnp .complex128
349+
350+ result_shape = ShapeDtypeStruct (
351+ shape = out_shape ,
352+ dtype = out_dtype ,
390353 )
391354
355+ return jax .pure_callback (
356+ lambda img : self ._pynufft_forward_numpy (img ),
357+ result_shape ,
358+ image ,
359+ )
360+
361+ def visibilities_from (self , image , xp = np ):
362+
363+ # start with native image padded with zeros
364+ image_native = xp .zeros (image .mask .shape , dtype = image .dtype )
365+
366+ if xp .__name__ .startswith ("jax" ):
367+
368+ image_native = image_native .at [image .mask .slim_to_native_tuple ].set (
369+ image .array
370+ )
371+
372+ else :
373+
374+ image_native [image .mask .slim_to_native_tuple ] = image .array
375+
376+ if xp is np :
377+ warnings .filterwarnings ("ignore" )
378+ return Visibilities (
379+ visibilities = self .forward (image_native [::- 1 , :])
380+ )
381+
382+ else :
383+
384+ vis = self .visibilities_from_jax (image_native )
385+
386+ return Visibilities (visibilities = vis )
387+
392388 def image_from (
393389 self , visibilities : Visibilities , use_adjoint_scaling : bool = False , xp = np
394390 ) -> Array2D :
@@ -446,16 +442,27 @@ def transform_mapping_matrix(self, mapping_matrix: np.ndarray, xp=np) -> np.ndar
446442
447443 for source_pixel_1d_index in range (mapping_matrix .shape [1 ]):
448444
449- image_2d = array_2d_util .array_2d_native_from (
450- array_2d_slim = mapping_matrix [:, source_pixel_1d_index ],
451- mask_2d = self .grid .mask ,
452- xp = xp ,
453- )
445+ image_2d = xp .zeros (self .grid .shape_native , dtype = mapping_matrix .dtype )
446+
447+ if xp .__name__ .startswith ("jax" ):
448+
449+ image_2d = image_2d .at [self .grid .mask .slim_to_native_tuple ].set (
450+ mapping_matrix [:, source_pixel_1d_index ]
451+ )
452+
453+ else :
454+
455+ image_2d [self .grid .mask .slim_to_native_tuple ] = mapping_matrix [
456+ :, source_pixel_1d_index
457+ ]
454458
455459 image = Array2D (values = image_2d , mask = self .grid .mask )
456460
457461 visibilities = self .visibilities_from (image = image , xp = xp )
458462
459- transformed_mapping_matrix [:, source_pixel_1d_index ] = visibilities
463+ if xp .__name__ .startswith ("jax" ):
464+ transformed_mapping_matrix = transformed_mapping_matrix .at [:, source_pixel_1d_index ].set (visibilities .array )
465+ else :
466+ transformed_mapping_matrix [:, source_pixel_1d_index ] = visibilities .array
460467
461468 return transformed_mapping_matrix
0 commit comments