Skip to content

Commit 266ffd8

Browse files
authored
Merge pull request #194 from Jammy2211/feature/simplify_slim_to_native_tuple
Feature/simplify slim to native tuple
2 parents 2b31e98 + 21ccc39 commit 266ffd8

File tree

15 files changed

+69
-169
lines changed

15 files changed

+69
-169
lines changed

autoarray/abstract_ndarray.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -338,6 +338,7 @@ def __getitem__(self, item):
338338

339339
try:
340340
import jax.numpy as jnp
341+
341342
if isinstance(result, jnp.ndarray):
342343
result = self.with_new_array(result)
343344
except ImportError:
@@ -351,6 +352,7 @@ def __setitem__(self, key, value):
351352
self._array[key] = value
352353
else:
353354
import jax.numpy as jnp
355+
354356
self._array = jnp.where(key, value, self._array)
355357

356358
def __repr__(self):

autoarray/dataset/imaging/dataset.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,8 @@
33
from pathlib import Path
44
from typing import Optional, Union
55

6+
from autoconf import cached_property
7+
68
from autoarray.dataset.abstract.dataset import AbstractDataset
79
from autoarray.dataset.grids import GridsDataset
810
from autoarray.dataset.imaging.w_tilde import WTildeImaging
@@ -191,7 +193,7 @@ def __init__(
191193
psf=self.psf,
192194
)
193195

194-
@property
196+
@cached_property
195197
def w_tilde(self):
196198
"""
197199
The w_tilde formalism of the linear algebra equations precomputes the convolution of every pair of masked

autoarray/dataset/interferometer/dataset.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@
22
import numpy as np
33
from pathlib import Path
44

5+
from autoconf import cached_property
6+
57
from autoconf.fitsable import ndarray_via_fits_from, output_to_fits
68

79
from autoarray.dataset.abstract.dataset import AbstractDataset
@@ -14,6 +16,8 @@
1416

1517
from autoarray.inversion.inversion.interferometer import inversion_interferometer_util
1618

19+
from autoarray import exc
20+
1721
logger = logging.getLogger(__name__)
1822

1923

@@ -165,7 +169,7 @@ def w_tilde_preprocessing(self):
165169

166170
fits.writeto(filename, data=curvature_preload)
167171

168-
@property
172+
@cached_property
169173
def w_tilde(self):
170174
"""
171175
The w_tilde formalism of the linear algebra equations precomputes the Fourier Transform of all the visibilities

autoarray/fit/fit_util.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -198,7 +198,7 @@ def residual_map_with_mask_from(
198198
model_data
199199
The model data used to fit the data.
200200
"""
201-
return xp.where(xp.asarray(mask) == 0, xp.subtract(data, model_data), 0)
201+
return xp.where(mask == 0, xp.subtract(data, model_data), 0)
202202

203203

204204
@to_new_array
@@ -221,7 +221,7 @@ def normalized_residual_map_with_mask_from(
221221
mask
222222
The mask applied to the residual-map, where `False` entries are included in the calculation.
223223
"""
224-
return xp.where(xp.asarray(mask) == 0, xp.divide(residual_map, noise_map), 0)
224+
return xp.where(mask == 0, xp.divide(residual_map, noise_map), 0)
225225

226226

227227
@to_new_array
@@ -244,7 +244,7 @@ def chi_squared_map_with_mask_from(
244244
mask
245245
The mask applied to the residual-map, where `False` entries are included in the calculation.
246246
"""
247-
return xp.where(xp.asarray(mask) == 0, xp.square(residual_map / noise_map), 0)
247+
return xp.where(mask == 0, xp.square(residual_map / noise_map), 0)
248248

249249

250250
def chi_squared_with_mask_from(
@@ -263,7 +263,7 @@ def chi_squared_with_mask_from(
263263
mask
264264
The mask applied to the chi-squared-map, where `False` entries are included in the calculation.
265265
"""
266-
return float(xp.sum(chi_squared_map[xp.asarray(mask) == 0]))
266+
return float(xp.sum(chi_squared_map[mask == 0]))
267267

268268

269269
def chi_squared_with_mask_fast_from(
@@ -301,8 +301,8 @@ def chi_squared_with_mask_fast_from(
301301
xp.subtract(
302302
data,
303303
model_data,
304-
)[xp.asarray(mask) == 0],
305-
noise_map[xp.asarray(mask) == 0],
304+
)[mask == 0],
305+
noise_map[mask == 0],
306306
)
307307
)
308308
)
@@ -326,7 +326,7 @@ def noise_normalization_with_mask_from(
326326
mask
327327
The mask applied to the noise-map, where `False` entries are included in the calculation.
328328
"""
329-
return float(xp.sum(xp.log(2 * xp.pi * noise_map[xp.asarray(mask) == 0] ** 2.0)))
329+
return float(xp.sum(xp.log(2 * xp.pi * noise_map[mask == 0] ** 2.0)))
330330

331331

332332
def chi_squared_with_noise_covariance_from(

autoarray/inversion/pixelization/mesh/rectangular.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -158,7 +158,12 @@ def requires_image_mesh(self):
158158

159159
class RectangularSource(RectangularMagnification):
160160

161-
def __init__(self, shape: Tuple[int, int] = (3, 3), weight_power: float = 1.0, weight_floor : float = 0.0):
161+
def __init__(
162+
self,
163+
shape: Tuple[int, int] = (3, 3),
164+
weight_power: float = 1.0,
165+
weight_floor: float = 0.0,
166+
):
162167
"""
163168
A uniform mesh of rectangular pixels, which without interpolation are paired with a 2D grid of (y,x)
164169
coordinates.
@@ -203,9 +208,9 @@ def mesh_weight_map_from(self, adapt_data, xp=np) -> np.ndarray:
203208
xp
204209
The array library to use.
205210
"""
206-
mesh_weight_map = xp.asarray(adapt_data.array)
211+
mesh_weight_map = adapt_data.array
207212
mesh_weight_map = xp.clip(mesh_weight_map, 1e-12, None)
208-
mesh_weight_map = mesh_weight_map ** self.weight_power
213+
mesh_weight_map = mesh_weight_map**self.weight_power
209214

210215
# Apply floor using xp.where (safe for NumPy and JAX)
211216
mesh_weight_map = xp.where(

autoarray/inversion/regularization/gaussian_kernel.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ def gauss_cov_matrix_from(
3333
The Gaussian covariance matrix.
3434
"""
3535
# Ensure array:
36-
pts = xp.asarray(pixel_points) # (N, 2)
36+
pts = pixel_points # (N, 2)
3737
# Compute squared distances: ||p_i - p_j||^2
3838
diffs = pts[:, None, :] - pts[None, :, :] # (N, N, 2)
3939
d2 = xp.sum(diffs**2, axis=-1) # (N, N)

autoarray/mask/mask_2d.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -217,6 +217,9 @@ def __init__(
217217
xp=xp,
218218
)
219219

220+
slim_to_native = self.derive_indexes.native_for_slim.astype("int32")
221+
self.slim_to_native_tuple = (slim_to_native[:, 0], slim_to_native[:, 1])
222+
220223
@property
221224
def native_for_slim(self):
222225
return self.derive_indexes.native_for_slim

autoarray/operators/over_sampling/over_sampler.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -258,10 +258,14 @@ def binned_array_2d_from(self, array: Array2D, xp=np) -> "Array2D":
258258
else:
259259

260260
# Sum values per segment
261-
sums = np.bincount(self.segment_ids, weights=array, minlength=self.mask.pixels_in_mask)
261+
sums = np.bincount(
262+
self.segment_ids, weights=array, minlength=self.mask.pixels_in_mask
263+
)
262264

263265
# Count number of items per segment
264-
counts = np.bincount(self.segment_ids, minlength=self.mask.pixels_in_mask)
266+
counts = np.bincount(
267+
self.segment_ids, minlength=self.mask.pixels_in_mask
268+
)
265269

266270
# Avoid division by zero
267271
counts[counts == 0] = 1

autoarray/operators/transformer.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -142,7 +142,7 @@ def visibilities_from(self, image: Array2D, xp=np) -> Visibilities:
142142
image_1d=image.slim.array,
143143
grid_radians=self.grid.array,
144144
uv_wavelengths=self.uv_wavelengths,
145-
xp=xp
145+
xp=xp,
146146
)
147147

148148
return Visibilities(visibilities=xp.array(visibilities))

autoarray/operators/transformer_util.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -247,7 +247,10 @@ def transformed_mapping_matrix_via_preload_from(
247247

248248

249249
def transformed_mapping_matrix_from(
250-
mapping_matrix: np.ndarray, grid_radians: np.ndarray, uv_wavelengths: np.ndarray, xp=np
250+
mapping_matrix: np.ndarray,
251+
grid_radians: np.ndarray,
252+
uv_wavelengths: np.ndarray,
253+
xp=np,
251254
) -> np.ndarray:
252255
"""
253256
Computes the Fourier-transformed mapping matrix used in radio interferometric imaging.

0 commit comments

Comments
 (0)