@@ -39,7 +39,6 @@ def __init__(
3939 uv_wavelengths : np .ndarray ,
4040 real_space_mask : Mask2D ,
4141 preload_transform : bool = True ,
42- xp = np ,
4342 ):
4443 """
4544 A direct Fourier transform (DFT) operator for radio interferometric imaging.
@@ -112,9 +111,7 @@ def __init__(
112111 2.0 * self .grid .shape_native [1 ]
113112 )
114113
115- self ._xp = xp
116-
117- def visibilities_from (self , image : Array2D ) -> Visibilities :
114+ def visibilities_from (self , image : Array2D , xp = np ) -> Visibilities :
118115 """
119116 Computes the visibilities from a real-space image using the direct Fourier transform (DFT).
120117
@@ -138,19 +135,20 @@ def visibilities_from(self, image: Array2D) -> Visibilities:
138135 image_1d = image .array ,
139136 preloaded_reals = self .preload_real_transforms ,
140137 preloaded_imags = self .preload_imag_transforms ,
141- xp = self . _xp ,
138+ xp = xp ,
142139 )
143140 else :
144141 visibilities = transformer_util .visibilities_from (
145142 image_1d = image .slim .array ,
146143 grid_radians = self .grid .array ,
147144 uv_wavelengths = self .uv_wavelengths ,
145+ xp = xp
148146 )
149147
150- return Visibilities (visibilities = self . _xp .array (visibilities ))
148+ return Visibilities (visibilities = xp .array (visibilities ))
151149
152150 def image_from (
153- self , visibilities : Visibilities , use_adjoint_scaling : bool = False
151+ self , visibilities : Visibilities , use_adjoint_scaling : bool = False , xp = np
154152 ) -> Array2D :
155153 """
156154 Computes the real-space image from a set of visibilities using the adjoint of the DFT.
@@ -178,12 +176,12 @@ def image_from(
178176 )
179177
180178 image_native = array_2d_util .array_2d_native_from (
181- array_2d_slim = image_slim , mask_2d = self .real_space_mask , xp = self . _xp
179+ array_2d_slim = image_slim , mask_2d = self .real_space_mask , xp = xp
182180 )
183181
184182 return Array2D (values = image_native , mask = self .real_space_mask )
185183
186- def transform_mapping_matrix (self , mapping_matrix : np .ndarray ) -> np .ndarray :
184+ def transform_mapping_matrix (self , mapping_matrix : np .ndarray , xp = np ) -> np .ndarray :
187185 """
188186 Applies the DFT to a mapping matrix that maps source pixels to image pixels.
189187
@@ -310,8 +308,6 @@ def __init__(
310308 2.0 * self .grid .shape_native [1 ]
311309 )
312310
313- self ._xp = xp
314-
315311 def initialize_plan (self , ratio : int = 2 , interp_kernel : Tuple [int , int ] = (6 , 6 )):
316312 """
317313 Initializes the PyNUFFT plan for performing the NUFFT operation.
@@ -394,7 +390,7 @@ def visibilities_from(self, image: Array2D) -> Visibilities:
394390 )
395391
396392 def image_from (
397- self , visibilities : Visibilities , use_adjoint_scaling : bool = False
393+ self , visibilities : Visibilities , use_adjoint_scaling : bool = False , xp = np
398394 ) -> Array2D :
399395 """
400396 Reconstructs a real-space image from visibilities using the NUFFT adjoint transform.
@@ -425,24 +421,24 @@ def image_from(
425421
426422 return Array2D (values = image , mask = self .real_space_mask )
427423
428- def transform_mapping_matrix (self , mapping_matrix : np .ndarray ) -> np .ndarray :
424+ def transform_mapping_matrix (self , mapping_matrix : np .ndarray , xp = np ) -> np .ndarray :
429425 """
430- Applies the NUFFT forward transform to each column of a mapping matrix, producing transformed visibilities.
426+ Applies the NUFFT forward transform to each column of a mapping matrix, producing transformed visibilities.
431427
432- Parameters
433- ----------
434- mapping_matrix
435- A 2D array where each column corresponds to a source-plane pixel intensity distribution flattened into image space.
428+ Parameters
429+ ----------
430+ mapping_matrix
431+ A 2D array where each column corresponds to a source-plane pixel intensity distribution flattened into image space.
436432
437- Returns
433+ Returns
438434 -------
439- A complex-valued 2D array where each column contains the visibilities corresponding to the respective column
440- in the input mapping matrix.
435+ A complex-valued 2D array where each column contains the visibilities corresponding to the respective column
436+ in the input mapping matrix.
441437
442- Notes
443- -----
444- - Each column of the input mapping matrix is reshaped into the native 2D image grid before transformation.
445- - This method repeatedly calls `visibilities_from` for each column, which may be computationally intensive.
438+ Notes
439+ -----
440+ - Each column of the input mapping matrix is reshaped into the native 2D image grid before transformation.
441+ - This method repeatedly calls `visibilities_from` for each column, which may be computationally intensive.
446442 """
447443 transformed_mapping_matrix = 0 + 0j * np .zeros (
448444 (self .uv_wavelengths .shape [0 ], mapping_matrix .shape [1 ])
@@ -452,7 +448,7 @@ def transform_mapping_matrix(self, mapping_matrix: np.ndarray) -> np.ndarray:
452448 image_2d = array_2d_util .array_2d_native_from (
453449 array_2d_slim = mapping_matrix [:, source_pixel_1d_index ],
454450 mask_2d = self .grid .mask ,
455- xp = self . _xp ,
451+ xp = xp ,
456452 )
457453
458454 image = Array2D (values = image_2d , mask = self .grid .mask )
0 commit comments