Skip to content

Commit 87b6ff8

Browse files
Jammy2211Jammy2211
authored andcommitted
fix or remove curvature preload unit tests
1 parent c9605b5 commit 87b6ff8

File tree

4 files changed

+49
-121
lines changed

4 files changed

+49
-121
lines changed

autoarray/dataset/interferometer/w_tilde.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@ def __init__(
4545
inversion_interferometer_util,
4646
)
4747

48-
self.operator_state = inversion_interferometer_util.w_tilde_fft_state_from(
48+
self.fft_state = inversion_interferometer_util.w_tilde_fft_state_from(
4949
curvature_preload=self.curvature_preload, batch_size=450
5050
)
5151

autoarray/inversion/inversion/interferometer/inversion_interferometer_util.py

Lines changed: 21 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -67,12 +67,12 @@ def _report_memory(arr):
6767
"""
6868
try:
6969
import resource
70+
7071
rss_mb = resource.getrusage(resource.RUSAGE_SELF).ru_maxrss / 1024
7172
arr_mb = arr.nbytes / 1024**2
7273
from tqdm import tqdm
73-
tqdm.write(
74-
f" Memory: array={arr_mb:.1f} MB, RSS≈{rss_mb:.1f} MB"
75-
)
74+
75+
tqdm.write(f" Memory: array={arr_mb:.1f} MB, RSS≈{rss_mb:.1f} MB")
7676
except Exception:
7777
pass
7878

@@ -141,26 +141,26 @@ def w_tilde_curvature_preload_interferometer_from(
141141
- The values of pixels paired with themselves are also computed repeatedly for the standard calculation (e.g. 9
142142
times using the mask above).
143143
144-
The `w_tilde_preload` method instead only computes each value once. To do this, it stores the preload values in a
144+
The `curvature_preload` method instead only computes each value once. To do this, it stores the preload values in a
145145
matrix of dimensions [shape_masked_pixels_y, shape_masked_pixels_x, 2], where `shape_masked_pixels` is the (y,x)
146146
size of the vertical and horizontal extent of unmasked pixels, e.g. the spatial extent over which the real space
147147
grid extends.
148148
149-
Each entry in the matrix `w_tilde_preload[:,:,0]` provides the the precomputed NUFFT value mapping an image pixel
149+
Each entry in the matrix `curvature_preload[:,:,0]` provides the the precomputed NUFFT value mapping an image pixel
150150
to a pixel offset by that much in the y and x directions, for example:
151151
152-
- w_tilde_preload[0,0,0] gives the precomputed values of image pixels that are offset in the y direction by 0 and
152+
- curvature_preload[0,0,0] gives the precomputed values of image pixels that are offset in the y direction by 0 and
153153
in the x direction by 0 - the values of pixels paired with themselves.
154-
- w_tilde_preload[1,0,0] gives the precomputed values of image pixels that are offset in the y direction by 1 and
154+
- curvature_preload[1,0,0] gives the precomputed values of image pixels that are offset in the y direction by 1 and
155155
in the x direction by 0 - the values of pixel pairs [0,3], [1,4], [2,5], [3,6], [4,7] and [5,8]
156-
- w_tilde_preload[0,1,0] gives the precomputed values of image pixels that are offset in the y direction by 0 and
156+
- curvature_preload[0,1,0] gives the precomputed values of image pixels that are offset in the y direction by 0 and
157157
in the x direction by 1 - the values of pixel pairs [0,1], [1,2], [3,4], [4,5], [6,7] and [7,9].
158158
159159
Flipped pairs:
160160
161161
The above preloaded values pair all image pixel NUFFT values when a pixel is to the right and / or down of the
162162
first image pixel. However, one must also precompute pairs where the paired pixel is to the left of the host
163-
pixels. These pairings are stored in `w_tilde_preload[:,:,1]`, and the ordering of these pairings is flipped in the
163+
pixels. These pairings are stored in `curvature_preload[:,:,1]`, and the ordering of these pairings is flipped in the
164164
x direction to make it straight forward to use this matrix when computing w_tilde.
165165
166166
Parameters
@@ -196,7 +196,7 @@ def w_tilde_curvature_preload_interferometer_from(
196196

197197
K = uv_wavelengths.shape[0]
198198

199-
w = 1.0 / (noise_map_real ** 2)
199+
w = 1.0 / (noise_map_real**2)
200200
ku = 2.0 * np.pi * uv_wavelengths[:, 0]
201201
kv = 2.0 * np.pi * uv_wavelengths[:, 1]
202202

@@ -225,10 +225,7 @@ def accum_from_corner(y_ref, x_ref, gy_block, gx_block, label=""):
225225
for k0 in iterator:
226226
k1 = min(K, k0 + chunk_k)
227227

228-
phase = (
229-
dx[..., None] * ku[k0:k1]
230-
+ dy[..., None] * kv[k0:k1]
231-
)
228+
phase = dx[..., None] * ku[k0:k1] + dy[..., None] * kv[k0:k1]
232229
acc += np.sum(
233230
np.cos(phase) * w[k0:k1],
234231
axis=2,
@@ -242,27 +239,21 @@ def accum_from_corner(y_ref, x_ref, gy_block, gx_block, label=""):
242239
# -----------------------------
243240
# Main quadrant (+,+)
244241
# -----------------------------
245-
out[:y_shape, :x_shape] = accum_from_corner(
246-
y00, x00, gy, gx, label="(+,+)"
247-
)
242+
out[:y_shape, :x_shape] = accum_from_corner(y00, x00, gy, gx, label="(+,+)")
248243

249244
# -----------------------------
250245
# Flip in x (+,-)
251246
# -----------------------------
252247
if x_shape > 1:
253-
block = accum_from_corner(
254-
y0m, x0m, gy[:, ::-1], gx[:, ::-1], label="(+,-)"
255-
)
256-
out[:y_shape, -1:-(x_shape): -1] = block[:, 1:]
248+
block = accum_from_corner(y0m, x0m, gy[:, ::-1], gx[:, ::-1], label="(+,-)")
249+
out[:y_shape, -1:-(x_shape):-1] = block[:, 1:]
257250

258251
# -----------------------------
259252
# Flip in y (-,+)
260253
# -----------------------------
261254
if y_shape > 1:
262-
block = accum_from_corner(
263-
ym0, xm0, gy[::-1, :], gx[::-1, :], label="(-,+)"
264-
)
265-
out[-1:-(y_shape): -1, :x_shape] = block[1:, :]
255+
block = accum_from_corner(ym0, xm0, gy[::-1, :], gx[::-1, :], label="(-,+)")
256+
out[-1:-(y_shape):-1, :x_shape] = block[1:, :]
266257

267258
# -----------------------------
268259
# Flip in x and y (-,-)
@@ -271,21 +262,19 @@ def accum_from_corner(y_ref, x_ref, gy_block, gx_block, label=""):
271262
block = accum_from_corner(
272263
ymm, xmm, gy[::-1, ::-1], gx[::-1, ::-1], label="(-,-)"
273264
)
274-
out[-1:-(y_shape): -1, -1:-(x_shape): -1] = block[1:, 1:]
265+
out[-1:-(y_shape):-1, -1:-(x_shape):-1] = block[1:, 1:]
275266

276267
return out
277268

278269

279-
280-
281-
def w_tilde_via_preload_from(w_tilde_preload, native_index_for_slim_index):
270+
def w_tilde_via_preload_from(curvature_preload, native_index_for_slim_index):
282271
"""
283-
Use the preloaded w_tilde matrix (see `w_tilde_preload_interferometer_from`) to compute
272+
Use the preloaded w_tilde matrix (see `curvature_preload_interferometer_from`) to compute
284273
w_tilde (see `w_tilde_interferometer_from`) efficiently.
285274
286275
Parameters
287276
----------
288-
w_tilde_preload
277+
curvature_preload
289278
The preloaded values of the NUFFT that enable efficient computation of w_tilde.
290279
native_index_for_slim_index
291280
An array of shape [total_unmasked_pixels*sub_size] that maps every unmasked sub-pixel to its corresponding
@@ -311,7 +300,7 @@ def w_tilde_via_preload_from(w_tilde_preload, native_index_for_slim_index):
311300
y_diff = j_y - i_y
312301
x_diff = j_x - i_x
313302

314-
w_tilde_via_preload[i, j] = w_tilde_preload[y_diff, x_diff]
303+
w_tilde_via_preload[i, j] = curvature_preload[y_diff, x_diff]
315304

316305
for i in range(slim_size):
317306
for j in range(i, slim_size):

autoarray/inversion/inversion/interferometer/w_tilde.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -110,7 +110,7 @@ def curvature_matrix_diag(self) -> np.ndarray:
110110
mapper = self.cls_list_from(cls=AbstractMapper)[0]
111111

112112
return inversion_interferometer_util.curvature_matrix_via_w_tilde_interferometer_from(
113-
fft_state=self.w_tilde.operator_state,
113+
fft_state=self.w_tilde.fft_state,
114114
pix_indexes_for_sub_slim_index=mapper.pix_indexes_for_sub_slim_index,
115115
pix_weights_for_sub_slim_index=mapper.pix_weights_for_sub_slim_index,
116116
pix_pixels=self.linear_obj_list[0].params,

test_autoarray/inversion/inversion/interferometer/test_inversion_interferometer_util.py

Lines changed: 26 additions & 87 deletions
Original file line numberDiff line numberDiff line change
@@ -68,49 +68,14 @@ def test__data_vector_via_transformed_mapping_matrix_from():
6868
assert (data_vector_complex_via_blurred == data_vector_via_transformed).all()
6969

7070

71-
def test__w_tilde_curvature_interferometer_from():
72-
noise_map = np.array([1.0, 2.0, 3.0])
73-
uv_wavelengths = np.array([[0.0001, 2.0, 3000.0], [3000.0, 2.0, 0.0001]])
74-
75-
grid = aa.Grid2D.uniform(shape_native=(2, 2), pixel_scales=0.0005)
76-
77-
w_tilde = (
78-
aa.util.inversion_interferometer_numba.w_tilde_curvature_interferometer_from(
79-
noise_map_real=noise_map,
80-
uv_wavelengths=uv_wavelengths,
81-
grid_radians_slim=grid.array,
82-
)
83-
)
84-
85-
assert w_tilde == pytest.approx(
86-
np.array(
87-
[
88-
[1.25, 0.75, 1.24997, 0.74998],
89-
[0.75, 1.25, 0.74998, 1.24997],
90-
[1.24994, 0.74998, 1.25, 0.75],
91-
[0.74998, 1.24997, 0.75, 1.25],
92-
]
93-
),
94-
1.0e-4,
95-
)
96-
97-
98-
def test__curvature_matrix_via_w_tilde_preload_from():
71+
def test__curvature_matrix_via_curvature_preload_from():
9972
noise_map = np.array([1.0, 2.0, 3.0, 4.0, 5.0, 6.0])
10073
uv_wavelengths = np.array(
10174
[[0.0001, 2.0, 3000.0, 50000.0, 200000.0], [3000.0, 2.0, 0.0001, 10.0, 5000.0]]
10275
)
10376

10477
grid = aa.Grid2D.uniform(shape_native=(3, 3), pixel_scales=0.0005)
10578

106-
w_tilde = (
107-
aa.util.inversion_interferometer_numba.w_tilde_curvature_interferometer_from(
108-
noise_map_real=noise_map,
109-
uv_wavelengths=uv_wavelengths,
110-
grid_radians_slim=grid.array,
111-
)
112-
)
113-
11479
mapping_matrix = np.array(
11580
[
11681
[1.0, 0.0, 0.0],
@@ -125,34 +90,45 @@ def test__curvature_matrix_via_w_tilde_preload_from():
12590
]
12691
)
12792

128-
curvature_matrix_via_w_tilde = aa.util.inversion.curvature_matrix_via_w_tilde_from(
129-
w_tilde=w_tilde, mapping_matrix=mapping_matrix
93+
curvature_preload = (
94+
aa.util.inversion_interferometer.w_tilde_curvature_preload_interferometer_from(
95+
noise_map_real=noise_map,
96+
uv_wavelengths=uv_wavelengths,
97+
shape_masked_pixels_2d=(3, 3),
98+
grid_radians_2d=np.array(grid.native),
99+
)
130100
)
131101

132-
w_tilde_preload = aa.util.inversion_interferometer_numba.w_tilde_curvature_preload_interferometer_from(
133-
noise_map_real=noise_map,
134-
uv_wavelengths=uv_wavelengths,
135-
shape_masked_pixels_2d=(3, 3),
136-
grid_radians_2d=np.array(grid.native),
102+
native_index_for_slim_index = np.array(
103+
[[0, 0], [0, 1], [0, 2], [1, 0], [1, 1], [1, 2], [2, 0], [2, 1], [2, 2]]
104+
)
105+
106+
w_tilde = aa.util.inversion_interferometer.w_tilde_via_preload_from(
107+
curvature_preload=curvature_preload,
108+
native_index_for_slim_index=native_index_for_slim_index,
109+
)
110+
111+
curvature_matrix_via_w_tilde = aa.util.inversion.curvature_matrix_via_w_tilde_from(
112+
w_tilde=w_tilde, mapping_matrix=mapping_matrix
137113
)
138114

139115
pix_indexes_for_sub_slim_index = np.array(
140116
[[0], [2], [1], [1], [2], [2], [0], [2], [0]]
141117
)
142118

143-
pix_size_for_sub_slim_index = np.ones(shape=(9,)).astype("int")
144119
pix_weights_for_sub_slim_index = np.ones(shape=(9, 1))
145120

146-
native_index_for_slim_index = np.array(
147-
[[0, 0], [0, 1], [0, 2], [1, 0], [1, 1], [1, 2], [2, 0], [2, 1], [2, 2]]
121+
w_tilde = aa.WTildeInterferometer(
122+
curvature_preload=curvature_preload,
123+
dirty_image=None,
124+
real_space_mask=grid.mask,
148125
)
149126

150-
curvature_matrix_via_preload = aa.util.inversion_interferometer_numba.curvature_matrix_via_w_tilde_curvature_preload_interferometer_from(
151-
curvature_preload=w_tilde_preload,
127+
curvature_matrix_via_preload = aa.util.inversion_interferometer.curvature_matrix_via_w_tilde_interferometer_from(
128+
fft_state=w_tilde.fft_state,
152129
pix_indexes_for_sub_slim_index=pix_indexes_for_sub_slim_index,
153-
pix_size_for_sub_slim_index=pix_size_for_sub_slim_index,
154130
pix_weights_for_sub_slim_index=pix_weights_for_sub_slim_index,
155-
native_index_for_slim_index=native_index_for_slim_index,
131+
rect_index_for_mask_index=w_tilde.rect_index_for_mask_index,
156132
pix_pixels=3,
157133
)
158134

@@ -161,43 +137,6 @@ def test__curvature_matrix_via_w_tilde_preload_from():
161137
)
162138

163139

164-
def test__curvature_matrix_via_w_tilde_two_methods_agree():
165-
noise_map = np.array([1.0, 2.0, 3.0, 4.0, 5.0, 6.0])
166-
uv_wavelengths = np.array(
167-
[[0.0001, 2.0, 3000.0, 50000.0, 200000.0], [3000.0, 2.0, 0.0001, 10.0, 5000.0]]
168-
)
169-
170-
grid = aa.Grid2D.uniform(shape_native=(3, 3), pixel_scales=0.0005)
171-
172-
w_tilde = (
173-
aa.util.inversion_interferometer_numba.w_tilde_curvature_interferometer_from(
174-
noise_map_real=noise_map,
175-
uv_wavelengths=uv_wavelengths,
176-
grid_radians_slim=grid.array,
177-
)
178-
)
179-
180-
w_tilde_preload = aa.util.inversion_interferometer_numba.w_tilde_curvature_preload_interferometer_from(
181-
noise_map_real=np.array(noise_map),
182-
uv_wavelengths=np.array(uv_wavelengths),
183-
shape_masked_pixels_2d=(3, 3),
184-
grid_radians_2d=np.array(grid.native),
185-
)
186-
187-
native_index_for_slim_index = np.array(
188-
[[0, 0], [0, 1], [0, 2], [1, 0], [1, 1], [1, 2], [2, 0], [2, 1], [2, 2]]
189-
)
190-
191-
w_tilde_via_preload = (
192-
aa.util.inversion_interferometer_numba.w_tilde_via_preload_from(
193-
w_tilde_preload=w_tilde_preload,
194-
native_index_for_slim_index=native_index_for_slim_index,
195-
)
196-
)
197-
198-
assert (w_tilde == w_tilde_via_preload).all()
199-
200-
201140
def test__identical_inversion_values_for_two_methods():
202141
real_space_mask = aa.Mask2D.all_false(
203142
shape_native=(7, 7),

0 commit comments

Comments
 (0)