Skip to content

Commit a5e3e78

Browse files
authored
Merge pull request #193 from Jammy2211/feature/xp_no_autofit_import
Feature/xp no autofit import
2 parents 9fe34bc + 314e2d0 commit a5e3e78

File tree

20 files changed

+797
-175
lines changed

20 files changed

+797
-175
lines changed

autoarray/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
from autoconf import jax_wrapper
12
from autoconf.dictable import register_parser
23
from autoconf import conf
34

autoarray/abstract_ndarray.py

Lines changed: 27 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,6 @@
44

55
from abc import ABC
66
from abc import abstractmethod
7-
import jax.numpy as jnp
8-
from jax._src.tree_util import register_pytree_node
97

108
import numpy as np
119

@@ -75,20 +73,20 @@ def __init__(self, array, xp=np):
7573
while isinstance(array, AbstractNDArray):
7674
array = array.array
7775
self._array = array
78-
try:
79-
register_pytree_node(
80-
type(self),
81-
self.instance_flatten,
82-
self.instance_unflatten,
83-
)
84-
except ValueError:
85-
pass
76+
# try:
77+
# register_pytree_node(
78+
# type(self),
79+
# self.instance_flatten,
80+
# self.instance_unflatten,
81+
# )
82+
# except ValueError:
83+
# pass
8684

8785
self._xp = xp
8886

8987
def invert(self):
9088
new = self.copy()
91-
new._array = jnp.invert(new._array)
89+
new._array = self._xp.invert(new._array)
9290
return new
9391

9492
@classmethod
@@ -117,7 +115,7 @@ def instance_unflatten(cls, aux_data, children):
117115
setattr(instance, key, value)
118116
return instance
119117

120-
def with_new_array(self, array: jnp.ndarray) -> "AbstractNDArray":
118+
def with_new_array(self, array: np.ndarray) -> "AbstractNDArray":
121119
"""
122120
Copy this object but give it a new array.
123121
@@ -137,10 +135,9 @@ def with_new_array(self, array: jnp.ndarray) -> "AbstractNDArray":
137135
new_array._array = array
138136
return new_array
139137

140-
@staticmethod
141-
def flip_hdu_for_ds9(values):
138+
def flip_hdu_for_ds9(self, values):
142139
if conf.instance["general"]["fits"]["flip_for_ds9"]:
143-
return jnp.flipud(values)
140+
return self._xp.flipud(values)
144141
return values
145142

146143
def copy(self):
@@ -170,7 +167,7 @@ def __iter__(self):
170167

171168
@to_new_array
172169
def sqrt(self):
173-
return jnp.sqrt(self._array)
170+
return self._xp.sqrt(self._array)
174171

175172
@property
176173
def array(self):
@@ -333,20 +330,28 @@ def __getattr__(self, item):
333330
)
334331

335332
def __getitem__(self, item):
333+
336334
result = self._array[item]
335+
337336
if isinstance(item, slice):
338337
result = self.with_new_array(result)
339-
if isinstance(result, jnp.ndarray):
340-
result = self.with_new_array(result)
338+
339+
try:
340+
import jax.numpy as jnp
341+
if isinstance(result, jnp.ndarray):
342+
result = self.with_new_array(result)
343+
except ImportError:
344+
pass
345+
341346
return result
342347

343348
def __setitem__(self, key, value):
344-
from jax import Array
345349

346-
if isinstance(key, (jnp.ndarray, AbstractNDArray, Array)):
347-
self._array = jnp.where(key, value, self._array)
348-
else:
350+
if isinstance(self._array, np.ndarray):
349351
self._array[key] = value
352+
else:
353+
import jax.numpy as jnp
354+
self._array = jnp.where(key, value, self._array)
350355

351356
def __repr__(self):
352357
return repr(self._array).replace(

autoarray/config/general.yaml

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,3 @@
1-
jax:
2-
use_jax: true # If True, uses JAX internally, whereas False uses normal Numpy.
31
fits:
42
flip_for_ds9: false # If True, the image is flipped before output to a .fits file, which is useful for viewing in DS9.
53
psf:

autoarray/fit/fit_dataset.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -163,7 +163,7 @@ def subtracted_from(grid, offset):
163163
if grid is None:
164164
return None
165165

166-
return grid.subtracted_from(offset=offset)
166+
return grid.subtracted_from(offset=offset, xp=self._xp)
167167

168168
lp = subtracted_from(
169169
grid=self.dataset.grids.lp, offset=self.dataset_model.grid_offset

autoarray/inversion/inversion/interferometer/abstract.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,7 @@ def operated_mapping_matrix_list(self) -> List[np.ndarray]:
6565
"""
6666
return [
6767
self.transformer.transform_mapping_matrix(
68-
mapping_matrix=linear_obj.mapping_matrix
68+
mapping_matrix=linear_obj.mapping_matrix, xp=self._xp
6969
)
7070
for linear_obj in self.linear_obj_list
7171
]

autoarray/mask/derive/indexes_2d.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@
22
import logging
33
import numpy as np
44

5-
from jax._src.tree_util import register_pytree_node_class
65
from typing import TYPE_CHECKING
76

87
if TYPE_CHECKING:
@@ -14,7 +13,6 @@
1413
logger = logging.getLogger(__name__)
1514

1615

17-
@register_pytree_node_class
1816
class DeriveIndexes2D:
1917

2018
def __init__(self, mask: Mask2D, xp=np):

autoarray/operators/over_sampling/over_sampler.py

Lines changed: 21 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
import numpy as np
22

3-
from jax._src.tree_util import register_pytree_node_class
43
from typing import Union
54

65
from autoconf import conf
@@ -11,7 +10,6 @@
1110
from autoarray.operators.over_sampling import over_sample_util
1211

1312

14-
@register_pytree_node_class
1513
class OverSampler:
1614
def __init__(self, mask: Mask2D, sub_size: Union[int, Array2D]):
1715
"""
@@ -229,6 +227,7 @@ def binned_array_2d_from(self, array: Array2D, xp=np) -> "Array2D":
229227
Sub-pixels that are part of the same mask array pixel are indexed next to one another, such that the second
230228
sub-pixel in the first pixel has index 1, its next sub-pixel has index 2, and so forth.
231229
"""
230+
232231
if conf.instance["general"]["structures"]["native_binned_only"]:
233232
return self
234233

@@ -245,16 +244,28 @@ def binned_array_2d_from(self, array: Array2D, xp=np) -> "Array2D":
245244

246245
else:
247246

248-
import jax
247+
if xp.__name__.startswith("jax"):
249248

250-
# Compute the group means
249+
import jax
250+
251+
sums = jax.ops.segment_sum(
252+
array, self.segment_ids, self.mask.pixels_in_mask
253+
)
254+
counts = jax.ops.segment_sum(
255+
xp.ones_like(array), self.segment_ids, self.mask.pixels_in_mask
256+
)
257+
258+
else:
259+
260+
# Sum values per segment
261+
sums = np.bincount(self.segment_ids, weights=array, minlength=self.mask.pixels_in_mask)
262+
263+
# Count number of items per segment
264+
counts = np.bincount(self.segment_ids, minlength=self.mask.pixels_in_mask)
265+
266+
# Avoid division by zero
267+
counts[counts == 0] = 1
251268

252-
sums = jax.ops.segment_sum(
253-
array, self.segment_ids, self.mask.pixels_in_mask
254-
)
255-
counts = jax.ops.segment_sum(
256-
xp.ones_like(array), self.segment_ids, self.mask.pixels_in_mask
257-
)
258269
binned_array_2d = sums / counts
259270

260271
return Array2D(

autoarray/operators/transformer.py

Lines changed: 22 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,6 @@ def __init__(
3939
uv_wavelengths: np.ndarray,
4040
real_space_mask: Mask2D,
4141
preload_transform: bool = True,
42-
xp=np,
4342
):
4443
"""
4544
A direct Fourier transform (DFT) operator for radio interferometric imaging.
@@ -112,9 +111,7 @@ def __init__(
112111
2.0 * self.grid.shape_native[1]
113112
)
114113

115-
self._xp = xp
116-
117-
def visibilities_from(self, image: Array2D) -> Visibilities:
114+
def visibilities_from(self, image: Array2D, xp=np) -> Visibilities:
118115
"""
119116
Computes the visibilities from a real-space image using the direct Fourier transform (DFT).
120117
@@ -138,19 +135,20 @@ def visibilities_from(self, image: Array2D) -> Visibilities:
138135
image_1d=image.array,
139136
preloaded_reals=self.preload_real_transforms,
140137
preloaded_imags=self.preload_imag_transforms,
141-
xp=self._xp,
138+
xp=xp,
142139
)
143140
else:
144141
visibilities = transformer_util.visibilities_from(
145142
image_1d=image.slim.array,
146143
grid_radians=self.grid.array,
147144
uv_wavelengths=self.uv_wavelengths,
145+
xp=xp
148146
)
149147

150-
return Visibilities(visibilities=self._xp.array(visibilities))
148+
return Visibilities(visibilities=xp.array(visibilities))
151149

152150
def image_from(
153-
self, visibilities: Visibilities, use_adjoint_scaling: bool = False
151+
self, visibilities: Visibilities, use_adjoint_scaling: bool = False, xp=np
154152
) -> Array2D:
155153
"""
156154
Computes the real-space image from a set of visibilities using the adjoint of the DFT.
@@ -178,12 +176,12 @@ def image_from(
178176
)
179177

180178
image_native = array_2d_util.array_2d_native_from(
181-
array_2d_slim=image_slim, mask_2d=self.real_space_mask, xp=self._xp
179+
array_2d_slim=image_slim, mask_2d=self.real_space_mask, xp=xp
182180
)
183181

184182
return Array2D(values=image_native, mask=self.real_space_mask)
185183

186-
def transform_mapping_matrix(self, mapping_matrix: np.ndarray) -> np.ndarray:
184+
def transform_mapping_matrix(self, mapping_matrix: np.ndarray, xp=np) -> np.ndarray:
187185
"""
188186
Applies the DFT to a mapping matrix that maps source pixels to image pixels.
189187
@@ -310,8 +308,6 @@ def __init__(
310308
2.0 * self.grid.shape_native[1]
311309
)
312310

313-
self._xp = xp
314-
315311
def initialize_plan(self, ratio: int = 2, interp_kernel: Tuple[int, int] = (6, 6)):
316312
"""
317313
Initializes the PyNUFFT plan for performing the NUFFT operation.
@@ -394,7 +390,7 @@ def visibilities_from(self, image: Array2D) -> Visibilities:
394390
)
395391

396392
def image_from(
397-
self, visibilities: Visibilities, use_adjoint_scaling: bool = False
393+
self, visibilities: Visibilities, use_adjoint_scaling: bool = False, xp=np
398394
) -> Array2D:
399395
"""
400396
Reconstructs a real-space image from visibilities using the NUFFT adjoint transform.
@@ -425,24 +421,24 @@ def image_from(
425421

426422
return Array2D(values=image, mask=self.real_space_mask)
427423

428-
def transform_mapping_matrix(self, mapping_matrix: np.ndarray) -> np.ndarray:
424+
def transform_mapping_matrix(self, mapping_matrix: np.ndarray, xp=np) -> np.ndarray:
429425
"""
430-
Applies the NUFFT forward transform to each column of a mapping matrix, producing transformed visibilities.
426+
Applies the NUFFT forward transform to each column of a mapping matrix, producing transformed visibilities.
431427
432-
Parameters
433-
----------
434-
mapping_matrix
435-
A 2D array where each column corresponds to a source-plane pixel intensity distribution flattened into image space.
428+
Parameters
429+
----------
430+
mapping_matrix
431+
A 2D array where each column corresponds to a source-plane pixel intensity distribution flattened into image space.
436432
437-
Returns
433+
Returns
438434
-------
439-
A complex-valued 2D array where each column contains the visibilities corresponding to the respective column
440-
in the input mapping matrix.
435+
A complex-valued 2D array where each column contains the visibilities corresponding to the respective column
436+
in the input mapping matrix.
441437
442-
Notes
443-
-----
444-
- Each column of the input mapping matrix is reshaped into the native 2D image grid before transformation.
445-
- This method repeatedly calls `visibilities_from` for each column, which may be computationally intensive.
438+
Notes
439+
-----
440+
- Each column of the input mapping matrix is reshaped into the native 2D image grid before transformation.
441+
- This method repeatedly calls `visibilities_from` for each column, which may be computationally intensive.
446442
"""
447443
transformed_mapping_matrix = 0 + 0j * np.zeros(
448444
(self.uv_wavelengths.shape[0], mapping_matrix.shape[1])
@@ -452,7 +448,7 @@ def transform_mapping_matrix(self, mapping_matrix: np.ndarray) -> np.ndarray:
452448
image_2d = array_2d_util.array_2d_native_from(
453449
array_2d_slim=mapping_matrix[:, source_pixel_1d_index],
454450
mask_2d=self.grid.mask,
455-
xp=self._xp,
451+
xp=xp,
456452
)
457453

458454
image = Array2D(values=image_2d, mask=self.grid.mask)

autoarray/operators/transformer_util.py

Lines changed: 13 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -120,7 +120,7 @@ def visibilities_via_preload_from(
120120

121121

122122
def visibilities_from(
123-
image_1d: np.ndarray, grid_radians: np.ndarray, uv_wavelengths: np.ndarray
123+
image_1d: np.ndarray, grid_radians: np.ndarray, uv_wavelengths: np.ndarray, xp=np
124124
) -> np.ndarray:
125125
"""
126126
Compute complex visibilities from an input sky image using the Fourier transform,
@@ -150,19 +150,19 @@ def visibilities_from(
150150
# Compute the dot product for each pixel-uv pair
151151
phase = (
152152
-2.0
153-
* np.pi
153+
* xp.pi
154154
* (
155-
np.outer(grid_radians[:, 1], uv_wavelengths[:, 0])
156-
+ np.outer(grid_radians[:, 0], uv_wavelengths[:, 1])
155+
xp.outer(grid_radians[:, 1], uv_wavelengths[:, 0])
156+
+ xp.outer(grid_radians[:, 0], uv_wavelengths[:, 1])
157157
)
158158
) # shape (n_pixels, n_vis)
159159

160160
# Multiply image values with phase terms
161-
vis_real = image_1d[:, None] * np.cos(phase)
162-
vis_imag = image_1d[:, None] * np.sin(phase)
161+
vis_real = image_1d[:, None] * xp.cos(phase)
162+
vis_imag = image_1d[:, None] * xp.sin(phase)
163163

164164
# Sum over all pixels for each visibility
165-
visibilities = np.sum(vis_real + 1j * vis_imag, axis=0)
165+
visibilities = xp.sum(vis_real + 1j * vis_imag, axis=0)
166166

167167
return visibilities
168168

@@ -247,7 +247,7 @@ def transformed_mapping_matrix_via_preload_from(
247247

248248

249249
def transformed_mapping_matrix_from(
250-
mapping_matrix: np.ndarray, grid_radians: np.ndarray, uv_wavelengths: np.ndarray
250+
mapping_matrix: np.ndarray, grid_radians: np.ndarray, uv_wavelengths: np.ndarray, xp=np
251251
) -> np.ndarray:
252252
"""
253253
Computes the Fourier-transformed mapping matrix used in radio interferometric imaging.
@@ -273,16 +273,16 @@ def transformed_mapping_matrix_from(
273273
# Compute phase term: (n_image_pixels, n_visibilities)
274274
phase = (
275275
-2.0
276-
* np.pi
276+
* xp.pi
277277
* (
278-
np.outer(grid_radians[:, 1], uv_wavelengths[:, 0]) # y * u
279-
+ np.outer(grid_radians[:, 0], uv_wavelengths[:, 1]) # x * v
278+
xp.outer(grid_radians[:, 1], uv_wavelengths[:, 0]) # y * u
279+
+ xp.outer(grid_radians[:, 0], uv_wavelengths[:, 1]) # x * v
280280
)
281281
)
282282

283283
# Compute real and imaginary Fourier matrices
284-
fourier_real = np.cos(phase)
285-
fourier_imag = np.sin(phase)
284+
fourier_real = xp.cos(phase)
285+
fourier_imag = xp.sin(phase)
286286

287287
# Only compute contributions from non-zero mapping entries
288288
# This matrix multiplication is: (n_visibilities x n_image_pixels) dot (n_image_pixels x n_source_pixels)

0 commit comments

Comments
 (0)