Skip to content

Commit afb6e90

Browse files
Jammy2211Jammy2211
authored andcommitted
tming of functions including JAX one improved
1 parent 7979d3c commit afb6e90

File tree

2 files changed

+45
-24
lines changed

2 files changed

+45
-24
lines changed

autoarray/dataset/interferometer/dataset.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -193,7 +193,7 @@ def apply_w_tilde(
193193

194194
if curvature_preload is None:
195195

196-
logger.info("INTERFEROMETER - Computing W-Tilde... May take a moment.")
196+
logger.info("INTERFEROMETER Computing W-Tilde; runtime scales with visibility count and mask resolution, extreme inputs may exceed hours.")
197197

198198
curvature_preload = inversion_interferometer_util.w_tilde_curvature_preload_interferometer_from(
199199
noise_map_real=self.noise_map.array.real,

autoarray/inversion/inversion/interferometer/inversion_interferometer_util.py

Lines changed: 44 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
import numpy as np
44
from tqdm import tqdm
55
import os
6+
import time
67

78
logger = logging.getLogger(__name__)
89

@@ -263,7 +264,6 @@ def w_tilde_curvature_preload_interferometer_via_np_from(
263264
if chunk_k <= 0:
264265
raise ValueError("chunk_k must be a positive integer")
265266

266-
# Enforce float64 everywhere (matches your original implementation)
267267
noise_map_real = np.asarray(noise_map_real, dtype=np.float64)
268268
uv_wavelengths = np.asarray(uv_wavelengths, dtype=np.float64)
269269
grid_radians_2d = np.asarray(grid_radians_2d, dtype=np.float64)
@@ -274,8 +274,9 @@ def w_tilde_curvature_preload_interferometer_via_np_from(
274274
gx = grid[..., 1]
275275

276276
K = uv_wavelengths.shape[0]
277+
n_chunks = (K + chunk_k - 1) // chunk_k
277278

278-
w = 1.0 / (noise_map_real**2)
279+
w = 1.0 / (noise_map_real ** 2)
279280
ku = 2.0 * np.pi * uv_wavelengths[:, 0]
280281
kv = 2.0 * np.pi * uv_wavelengths[:, 1]
281282

@@ -287,31 +288,44 @@ def w_tilde_curvature_preload_interferometer_via_np_from(
287288
ym0, xm0 = gy[y_shape - 1, 0], gx[y_shape - 1, 0]
288289
ymm, xmm = gy[y_shape - 1, x_shape - 1], gx[y_shape - 1, x_shape - 1]
289290

290-
def accum_from_corner_np(y_ref, x_ref, gy_block, gx_block, label=""):
291+
# -------------------------------------------------
292+
# Set up a single global progress bar
293+
# -------------------------------------------------
294+
pbar = None
295+
if show_progress:
296+
try:
297+
from tqdm import tqdm # type: ignore
298+
299+
n_quadrants = 1
300+
if x_shape > 1:
301+
n_quadrants += 1
302+
if y_shape > 1:
303+
n_quadrants += 1
304+
if (y_shape > 1) and (x_shape > 1):
305+
n_quadrants += 1
306+
307+
pbar = tqdm(
308+
total=n_chunks * n_quadrants,
309+
desc="Accumulating visibilities (W-tilde preload)",
310+
)
311+
except Exception:
312+
pbar = None
313+
314+
def accum_from_corner_np(y_ref, x_ref, gy_block, gx_block):
291315
dy = y_ref - gy_block
292316
dx = x_ref - gx_block
293317

294318
acc = np.zeros(gy_block.shape, dtype=np.float64)
295319

296-
iterator = range(0, K, chunk_k)
297-
if show_progress:
298-
try:
299-
from tqdm import tqdm # type: ignore
300-
301-
iterator = tqdm(
302-
iterator,
303-
desc=f"Accumulating visibilities {label}",
304-
total=(K + chunk_k - 1) // chunk_k,
305-
)
306-
except Exception:
307-
pass # tqdm not installed; silently fall back
308-
309-
for k0 in iterator:
320+
for k0 in range(0, K, chunk_k):
310321
k1 = min(K, k0 + chunk_k)
311322

312323
phase = dx[..., None] * ku[k0:k1] + dy[..., None] * kv[k0:k1]
313324
acc += np.sum(np.cos(phase) * w[k0:k1], axis=2)
314325

326+
if pbar is not None:
327+
pbar.update(1)
328+
315329
if show_memory and show_progress and "_report_memory" in globals():
316330
try:
317331
globals()["_report_memory"](acc)
@@ -323,31 +337,32 @@ def accum_from_corner_np(y_ref, x_ref, gy_block, gx_block, label=""):
323337
# -----------------------------
324338
# Main quadrant (+,+)
325339
# -----------------------------
326-
out[:y_shape, :x_shape] = accum_from_corner_np(y00, x00, gy, gx, label="(+,+)")
340+
out[:y_shape, :x_shape] = accum_from_corner_np(y00, x00, gy, gx)
327341

328342
# -----------------------------
329343
# Flip in x (+,-)
330344
# -----------------------------
331345
if x_shape > 1:
332-
block = accum_from_corner_np(y0m, x0m, gy[:, ::-1], gx[:, ::-1], label="(+,-)")
346+
block = accum_from_corner_np(y0m, x0m, gy[:, ::-1], gx[:, ::-1])
333347
out[:y_shape, -1:-(x_shape):-1] = block[:, 1:]
334348

335349
# -----------------------------
336350
# Flip in y (-,+)
337351
# -----------------------------
338352
if y_shape > 1:
339-
block = accum_from_corner_np(ym0, xm0, gy[::-1, :], gx[::-1, :], label="(-,+)")
353+
block = accum_from_corner_np(ym0, xm0, gy[::-1, :], gx[::-1, :])
340354
out[-1:-(y_shape):-1, :x_shape] = block[1:, :]
341355

342356
# -----------------------------
343357
# Flip in x and y (-,-)
344358
# -----------------------------
345359
if (y_shape > 1) and (x_shape > 1):
346-
block = accum_from_corner_np(
347-
ymm, xmm, gy[::-1, ::-1], gx[::-1, ::-1], label="(-,-)"
348-
)
360+
block = accum_from_corner_np(ymm, xmm, gy[::-1, ::-1], gx[::-1, ::-1])
349361
out[-1:-(y_shape):-1, -1:-(x_shape):-1] = block[1:, 1:]
350362

363+
if pbar is not None:
364+
pbar.close()
365+
351366
return out
352367

353368

@@ -466,7 +481,13 @@ def body(i, acc_):
466481
_compute_all_quadrants, static_argnames=("chunk_k",)
467482
)
468483

484+
t0 = time.time()
469485
out = _compute_all_quadrants_jit(gy, gx, chunk_k=chunk_k)
486+
out.block_until_ready() # ensure timing includes actual device execution
487+
t1 = time.time()
488+
489+
logger.info("INTERFEROMETER - Finished W-Tilde (JAX) in %.3f seconds", (t1 - t0))
490+
470491
return np.asarray(out)
471492

472493

0 commit comments

Comments
 (0)