Skip to content

Commit 314e2d0

Browse files
Jammy2211Jammy2211
authored andcommitted
finish
1 parent 0c31531 commit 314e2d0

File tree

12 files changed

+732
-99
lines changed

12 files changed

+732
-99
lines changed

autoarray/abstract_ndarray.py

Lines changed: 12 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -331,24 +331,27 @@ def __getattr__(self, item):
331331

332332
def __getitem__(self, item):
333333

334-
import jax.numpy as jnp
335-
336334
result = self._array[item]
337335

338336
if isinstance(item, slice):
339337
result = self.with_new_array(result)
340-
if isinstance(result, jnp.ndarray):
341-
result = self.with_new_array(result)
338+
339+
try:
340+
import jax.numpy as jnp
341+
if isinstance(result, jnp.ndarray):
342+
result = self.with_new_array(result)
343+
except ImportError:
344+
pass
345+
342346
return result
343347

344348
def __setitem__(self, key, value):
345-
from jax import Array
346-
import jax.numpy as jnp
347349

348-
if isinstance(key, (jnp.ndarray, AbstractNDArray, Array)):
349-
self._array = jnp.where(key, value, self._array)
350-
else:
350+
if isinstance(self._array, np.ndarray):
351351
self._array[key] = value
352+
else:
353+
import jax.numpy as jnp
354+
self._array = jnp.where(key, value, self._array)
352355

353356
def __repr__(self):
354357
return repr(self._array).replace(

autoarray/fit/fit_dataset.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -163,7 +163,7 @@ def subtracted_from(grid, offset):
163163
if grid is None:
164164
return None
165165

166-
return grid.subtracted_from(offset=offset)
166+
return grid.subtracted_from(offset=offset, xp=self._xp)
167167

168168
lp = subtracted_from(
169169
grid=self.dataset.grids.lp, offset=self.dataset_model.grid_offset

autoarray/inversion/inversion/interferometer/abstract.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,7 @@ def operated_mapping_matrix_list(self) -> List[np.ndarray]:
6565
"""
6666
return [
6767
self.transformer.transform_mapping_matrix(
68-
mapping_matrix=linear_obj.mapping_matrix
68+
mapping_matrix=linear_obj.mapping_matrix, xp=self._xp
6969
)
7070
for linear_obj in self.linear_obj_list
7171
]

autoarray/operators/over_sampling/over_sampler.py

Lines changed: 21 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -227,6 +227,7 @@ def binned_array_2d_from(self, array: Array2D, xp=np) -> "Array2D":
227227
Sub-pixels that are part of the same mask array pixel are indexed next to one another, such that the second
228228
sub-pixel in the first pixel has index 1, its next sub-pixel has index 2, and so forth.
229229
"""
230+
230231
if conf.instance["general"]["structures"]["native_binned_only"]:
231232
return self
232233

@@ -243,16 +244,28 @@ def binned_array_2d_from(self, array: Array2D, xp=np) -> "Array2D":
243244

244245
else:
245246

246-
import jax
247+
if xp.__name__.startswith("jax"):
247248

248-
# Compute the group means
249+
import jax
250+
251+
sums = jax.ops.segment_sum(
252+
array, self.segment_ids, self.mask.pixels_in_mask
253+
)
254+
counts = jax.ops.segment_sum(
255+
xp.ones_like(array), self.segment_ids, self.mask.pixels_in_mask
256+
)
257+
258+
else:
259+
260+
# Sum values per segment
261+
sums = np.bincount(self.segment_ids, weights=array, minlength=self.mask.pixels_in_mask)
262+
263+
# Count number of items per segment
264+
counts = np.bincount(self.segment_ids, minlength=self.mask.pixels_in_mask)
265+
266+
# Avoid division by zero
267+
counts[counts == 0] = 1
249268

250-
sums = jax.ops.segment_sum(
251-
array, self.segment_ids, self.mask.pixels_in_mask
252-
)
253-
counts = jax.ops.segment_sum(
254-
xp.ones_like(array), self.segment_ids, self.mask.pixels_in_mask
255-
)
256269
binned_array_2d = sums / counts
257270

258271
return Array2D(

autoarray/operators/transformer.py

Lines changed: 22 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,6 @@ def __init__(
3939
uv_wavelengths: np.ndarray,
4040
real_space_mask: Mask2D,
4141
preload_transform: bool = True,
42-
xp=np,
4342
):
4443
"""
4544
A direct Fourier transform (DFT) operator for radio interferometric imaging.
@@ -112,9 +111,7 @@ def __init__(
112111
2.0 * self.grid.shape_native[1]
113112
)
114113

115-
self._xp = xp
116-
117-
def visibilities_from(self, image: Array2D) -> Visibilities:
114+
def visibilities_from(self, image: Array2D, xp=np) -> Visibilities:
118115
"""
119116
Computes the visibilities from a real-space image using the direct Fourier transform (DFT).
120117
@@ -138,19 +135,20 @@ def visibilities_from(self, image: Array2D) -> Visibilities:
138135
image_1d=image.array,
139136
preloaded_reals=self.preload_real_transforms,
140137
preloaded_imags=self.preload_imag_transforms,
141-
xp=self._xp,
138+
xp=xp,
142139
)
143140
else:
144141
visibilities = transformer_util.visibilities_from(
145142
image_1d=image.slim.array,
146143
grid_radians=self.grid.array,
147144
uv_wavelengths=self.uv_wavelengths,
145+
xp=xp
148146
)
149147

150-
return Visibilities(visibilities=self._xp.array(visibilities))
148+
return Visibilities(visibilities=xp.array(visibilities))
151149

152150
def image_from(
153-
self, visibilities: Visibilities, use_adjoint_scaling: bool = False
151+
self, visibilities: Visibilities, use_adjoint_scaling: bool = False, xp=np
154152
) -> Array2D:
155153
"""
156154
Computes the real-space image from a set of visibilities using the adjoint of the DFT.
@@ -178,12 +176,12 @@ def image_from(
178176
)
179177

180178
image_native = array_2d_util.array_2d_native_from(
181-
array_2d_slim=image_slim, mask_2d=self.real_space_mask, xp=self._xp
179+
array_2d_slim=image_slim, mask_2d=self.real_space_mask, xp=xp
182180
)
183181

184182
return Array2D(values=image_native, mask=self.real_space_mask)
185183

186-
def transform_mapping_matrix(self, mapping_matrix: np.ndarray) -> np.ndarray:
184+
def transform_mapping_matrix(self, mapping_matrix: np.ndarray, xp=np) -> np.ndarray:
187185
"""
188186
Applies the DFT to a mapping matrix that maps source pixels to image pixels.
189187
@@ -310,8 +308,6 @@ def __init__(
310308
2.0 * self.grid.shape_native[1]
311309
)
312310

313-
self._xp = xp
314-
315311
def initialize_plan(self, ratio: int = 2, interp_kernel: Tuple[int, int] = (6, 6)):
316312
"""
317313
Initializes the PyNUFFT plan for performing the NUFFT operation.
@@ -394,7 +390,7 @@ def visibilities_from(self, image: Array2D) -> Visibilities:
394390
)
395391

396392
def image_from(
397-
self, visibilities: Visibilities, use_adjoint_scaling: bool = False
393+
self, visibilities: Visibilities, use_adjoint_scaling: bool = False, xp=np
398394
) -> Array2D:
399395
"""
400396
Reconstructs a real-space image from visibilities using the NUFFT adjoint transform.
@@ -425,24 +421,24 @@ def image_from(
425421

426422
return Array2D(values=image, mask=self.real_space_mask)
427423

428-
def transform_mapping_matrix(self, mapping_matrix: np.ndarray) -> np.ndarray:
424+
def transform_mapping_matrix(self, mapping_matrix: np.ndarray, xp=np) -> np.ndarray:
429425
"""
430-
Applies the NUFFT forward transform to each column of a mapping matrix, producing transformed visibilities.
426+
Applies the NUFFT forward transform to each column of a mapping matrix, producing transformed visibilities.
431427
432-
Parameters
433-
----------
434-
mapping_matrix
435-
A 2D array where each column corresponds to a source-plane pixel intensity distribution flattened into image space.
428+
Parameters
429+
----------
430+
mapping_matrix
431+
A 2D array where each column corresponds to a source-plane pixel intensity distribution flattened into image space.
436432
437-
Returns
433+
Returns
438434
-------
439-
A complex-valued 2D array where each column contains the visibilities corresponding to the respective column
440-
in the input mapping matrix.
435+
A complex-valued 2D array where each column contains the visibilities corresponding to the respective column
436+
in the input mapping matrix.
441437
442-
Notes
443-
-----
444-
- Each column of the input mapping matrix is reshaped into the native 2D image grid before transformation.
445-
- This method repeatedly calls `visibilities_from` for each column, which may be computationally intensive.
438+
Notes
439+
-----
440+
- Each column of the input mapping matrix is reshaped into the native 2D image grid before transformation.
441+
- This method repeatedly calls `visibilities_from` for each column, which may be computationally intensive.
446442
"""
447443
transformed_mapping_matrix = 0 + 0j * np.zeros(
448444
(self.uv_wavelengths.shape[0], mapping_matrix.shape[1])
@@ -452,7 +448,7 @@ def transform_mapping_matrix(self, mapping_matrix: np.ndarray) -> np.ndarray:
452448
image_2d = array_2d_util.array_2d_native_from(
453449
array_2d_slim=mapping_matrix[:, source_pixel_1d_index],
454450
mask_2d=self.grid.mask,
455-
xp=self._xp,
451+
xp=xp,
456452
)
457453

458454
image = Array2D(values=image_2d, mask=self.grid.mask)

autoarray/operators/transformer_util.py

Lines changed: 13 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -120,7 +120,7 @@ def visibilities_via_preload_from(
120120

121121

122122
def visibilities_from(
123-
image_1d: np.ndarray, grid_radians: np.ndarray, uv_wavelengths: np.ndarray
123+
image_1d: np.ndarray, grid_radians: np.ndarray, uv_wavelengths: np.ndarray, xp=np
124124
) -> np.ndarray:
125125
"""
126126
Compute complex visibilities from an input sky image using the Fourier transform,
@@ -150,19 +150,19 @@ def visibilities_from(
150150
# Compute the dot product for each pixel-uv pair
151151
phase = (
152152
-2.0
153-
* np.pi
153+
* xp.pi
154154
* (
155-
np.outer(grid_radians[:, 1], uv_wavelengths[:, 0])
156-
+ np.outer(grid_radians[:, 0], uv_wavelengths[:, 1])
155+
xp.outer(grid_radians[:, 1], uv_wavelengths[:, 0])
156+
+ xp.outer(grid_radians[:, 0], uv_wavelengths[:, 1])
157157
)
158158
) # shape (n_pixels, n_vis)
159159

160160
# Multiply image values with phase terms
161-
vis_real = image_1d[:, None] * np.cos(phase)
162-
vis_imag = image_1d[:, None] * np.sin(phase)
161+
vis_real = image_1d[:, None] * xp.cos(phase)
162+
vis_imag = image_1d[:, None] * xp.sin(phase)
163163

164164
# Sum over all pixels for each visibility
165-
visibilities = np.sum(vis_real + 1j * vis_imag, axis=0)
165+
visibilities = xp.sum(vis_real + 1j * vis_imag, axis=0)
166166

167167
return visibilities
168168

@@ -247,7 +247,7 @@ def transformed_mapping_matrix_via_preload_from(
247247

248248

249249
def transformed_mapping_matrix_from(
250-
mapping_matrix: np.ndarray, grid_radians: np.ndarray, uv_wavelengths: np.ndarray
250+
mapping_matrix: np.ndarray, grid_radians: np.ndarray, uv_wavelengths: np.ndarray, xp=np
251251
) -> np.ndarray:
252252
"""
253253
Computes the Fourier-transformed mapping matrix used in radio interferometric imaging.
@@ -273,16 +273,16 @@ def transformed_mapping_matrix_from(
273273
# Compute phase term: (n_image_pixels, n_visibilities)
274274
phase = (
275275
-2.0
276-
* np.pi
276+
* xp.pi
277277
* (
278-
np.outer(grid_radians[:, 1], uv_wavelengths[:, 0]) # y * u
279-
+ np.outer(grid_radians[:, 0], uv_wavelengths[:, 1]) # x * v
278+
xp.outer(grid_radians[:, 1], uv_wavelengths[:, 0]) # y * u
279+
+ xp.outer(grid_radians[:, 0], uv_wavelengths[:, 1]) # x * v
280280
)
281281
)
282282

283283
# Compute real and imaginary Fourier matrices
284-
fourier_real = np.cos(phase)
285-
fourier_imag = np.sin(phase)
284+
fourier_real = xp.cos(phase)
285+
fourier_imag = xp.sin(phase)
286286

287287
# Only compute contributions from non-zero mapping entries
288288
# This matrix multiplication is: (n_visibilities x n_image_pixels) dot (n_image_pixels x n_source_pixels)

autoarray/preloads.py

Lines changed: 8 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -9,9 +9,7 @@
99

1010
def mapper_indices_from(total_linear_light_profiles, total_mapper_pixels):
1111

12-
import jax.numpy as jnp
13-
14-
return jnp.arange(
12+
return np.arange(
1513
total_linear_light_profiles,
1614
total_linear_light_profiles + total_mapper_pixels,
1715
dtype=int,
@@ -54,31 +52,28 @@ def __init__(
5452
is fixed to the maximum likelihood solution, allowing the blurred mapping matrix to be preloaded, but
5553
the intensity values will still be solved for during the inversion.
5654
"""
57-
import jax.numpy as jnp
58-
5955
self.mapper_indices = None
6056
self.source_pixel_zeroed_indices = None
6157
self.source_pixel_zeroed_indices_to_keep = None
6258
self.linear_light_profile_blurred_mapping_matrix = None
6359

6460
if mapper_indices is not None:
6561

66-
self.mapper_indices = jnp.array(mapper_indices)
62+
self.mapper_indices = np.array(mapper_indices)
6763

6864
if source_pixel_zeroed_indices is not None:
6965

70-
self.source_pixel_zeroed_indices = jnp.array(source_pixel_zeroed_indices)
66+
self.source_pixel_zeroed_indices = np.array(source_pixel_zeroed_indices)
7167

72-
ids_zeros = jnp.array(source_pixel_zeroed_indices, dtype=int)
68+
ids_zeros = np.array(source_pixel_zeroed_indices, dtype=int)
7369

74-
values_to_solve = jnp.ones(np.max(mapper_indices), dtype=bool)
75-
values_to_solve = values_to_solve.at[ids_zeros].set(False)
70+
values_to_solve = np.ones(np.max(mapper_indices)+1, dtype=bool)
71+
values_to_solve[ids_zeros] = False
7672

77-
# Get the indices where values_to_solve is True
78-
self.source_pixel_zeroed_indices_to_keep = jnp.where(values_to_solve)[0]
73+
self.source_pixel_zeroed_indices_to_keep = np.where(values_to_solve)[0]
7974

8075
if linear_light_profile_blurred_mapping_matrix is not None:
8176

82-
self.linear_light_profile_blurred_mapping_matrix = jnp.array(
77+
self.linear_light_profile_blurred_mapping_matrix = np.array(
8378
linear_light_profile_blurred_mapping_matrix
8479
)

autoarray/structures/triangles/abstract.py

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -12,21 +12,27 @@ def __len__(self):
1212
return len(self.triangles)
1313

1414
@property
15-
@abstractmethod
1615
def area(self) -> float:
1716
"""
1817
The total area covered by the triangles.
1918
"""
19+
triangles = self.triangles
20+
return (
21+
0.5
22+
* np.abs(
23+
(triangles[:, 0, 0] * (triangles[:, 1, 1] - triangles[:, 2, 1]))
24+
+ (triangles[:, 1, 0] * (triangles[:, 2, 1] - triangles[:, 0, 1]))
25+
+ (triangles[:, 2, 0] * (triangles[:, 0, 1] - triangles[:, 1, 1]))
26+
).sum()
27+
)
2028

2129
@property
22-
@abstractmethod
2330
def indices(self):
24-
pass
31+
return self._indices
2532

2633
@property
27-
@abstractmethod
2834
def vertices(self):
29-
pass
35+
return self._vertices
3036

3137
def __str__(self):
3238
return f"{self.__class__.__name__} with {len(self.indices)} triangles"

0 commit comments

Comments
 (0)