33import numpy as np
44from tqdm import tqdm
55import os
6+ import time
67
78logger = logging .getLogger (__name__ )
89
@@ -263,7 +264,6 @@ def w_tilde_curvature_preload_interferometer_via_np_from(
263264 if chunk_k <= 0 :
264265 raise ValueError ("chunk_k must be a positive integer" )
265266
266- # Enforce float64 everywhere (matches your original implementation)
267267 noise_map_real = np .asarray (noise_map_real , dtype = np .float64 )
268268 uv_wavelengths = np .asarray (uv_wavelengths , dtype = np .float64 )
269269 grid_radians_2d = np .asarray (grid_radians_2d , dtype = np .float64 )
@@ -274,8 +274,9 @@ def w_tilde_curvature_preload_interferometer_via_np_from(
274274 gx = grid [..., 1 ]
275275
276276 K = uv_wavelengths .shape [0 ]
277+ n_chunks = (K + chunk_k - 1 ) // chunk_k
277278
278- w = 1.0 / (noise_map_real ** 2 )
279+ w = 1.0 / (noise_map_real ** 2 )
279280 ku = 2.0 * np .pi * uv_wavelengths [:, 0 ]
280281 kv = 2.0 * np .pi * uv_wavelengths [:, 1 ]
281282
@@ -287,31 +288,44 @@ def w_tilde_curvature_preload_interferometer_via_np_from(
287288 ym0 , xm0 = gy [y_shape - 1 , 0 ], gx [y_shape - 1 , 0 ]
288289 ymm , xmm = gy [y_shape - 1 , x_shape - 1 ], gx [y_shape - 1 , x_shape - 1 ]
289290
290- def accum_from_corner_np (y_ref , x_ref , gy_block , gx_block , label = "" ):
291+ # -------------------------------------------------
292+ # Set up a single global progress bar
293+ # -------------------------------------------------
294+ pbar = None
295+ if show_progress :
296+ try :
297+ from tqdm import tqdm # type: ignore
298+
299+ n_quadrants = 1
300+ if x_shape > 1 :
301+ n_quadrants += 1
302+ if y_shape > 1 :
303+ n_quadrants += 1
304+ if (y_shape > 1 ) and (x_shape > 1 ):
305+ n_quadrants += 1
306+
307+ pbar = tqdm (
308+ total = n_chunks * n_quadrants ,
309+ desc = "Accumulating visibilities (W-tilde preload)" ,
310+ )
311+ except Exception :
312+ pbar = None
313+
314+ def accum_from_corner_np (y_ref , x_ref , gy_block , gx_block ):
291315 dy = y_ref - gy_block
292316 dx = x_ref - gx_block
293317
294318 acc = np .zeros (gy_block .shape , dtype = np .float64 )
295319
296- iterator = range (0 , K , chunk_k )
297- if show_progress :
298- try :
299- from tqdm import tqdm # type: ignore
300-
301- iterator = tqdm (
302- iterator ,
303- desc = f"Accumulating visibilities { label } " ,
304- total = (K + chunk_k - 1 ) // chunk_k ,
305- )
306- except Exception :
307- pass # tqdm not installed; silently fall back
308-
309- for k0 in iterator :
320+ for k0 in range (0 , K , chunk_k ):
310321 k1 = min (K , k0 + chunk_k )
311322
312323 phase = dx [..., None ] * ku [k0 :k1 ] + dy [..., None ] * kv [k0 :k1 ]
313324 acc += np .sum (np .cos (phase ) * w [k0 :k1 ], axis = 2 )
314325
326+ if pbar is not None :
327+ pbar .update (1 )
328+
315329 if show_memory and show_progress and "_report_memory" in globals ():
316330 try :
317331 globals ()["_report_memory" ](acc )
@@ -323,31 +337,32 @@ def accum_from_corner_np(y_ref, x_ref, gy_block, gx_block, label=""):
323337 # -----------------------------
324338 # Main quadrant (+,+)
325339 # -----------------------------
326- out [:y_shape , :x_shape ] = accum_from_corner_np (y00 , x00 , gy , gx , label = "(+,+)" )
340+ out [:y_shape , :x_shape ] = accum_from_corner_np (y00 , x00 , gy , gx )
327341
328342 # -----------------------------
329343 # Flip in x (+,-)
330344 # -----------------------------
331345 if x_shape > 1 :
332- block = accum_from_corner_np (y0m , x0m , gy [:, ::- 1 ], gx [:, ::- 1 ], label = "(+,-)" )
346+ block = accum_from_corner_np (y0m , x0m , gy [:, ::- 1 ], gx [:, ::- 1 ])
333347 out [:y_shape , - 1 :- (x_shape ):- 1 ] = block [:, 1 :]
334348
335349 # -----------------------------
336350 # Flip in y (-,+)
337351 # -----------------------------
338352 if y_shape > 1 :
339- block = accum_from_corner_np (ym0 , xm0 , gy [::- 1 , :], gx [::- 1 , :], label = "(-,+)" )
353+ block = accum_from_corner_np (ym0 , xm0 , gy [::- 1 , :], gx [::- 1 , :])
340354 out [- 1 :- (y_shape ):- 1 , :x_shape ] = block [1 :, :]
341355
342356 # -----------------------------
343357 # Flip in x and y (-,-)
344358 # -----------------------------
345359 if (y_shape > 1 ) and (x_shape > 1 ):
346- block = accum_from_corner_np (
347- ymm , xmm , gy [::- 1 , ::- 1 ], gx [::- 1 , ::- 1 ], label = "(-,-)"
348- )
360+ block = accum_from_corner_np (ymm , xmm , gy [::- 1 , ::- 1 ], gx [::- 1 , ::- 1 ])
349361 out [- 1 :- (y_shape ):- 1 , - 1 :- (x_shape ):- 1 ] = block [1 :, 1 :]
350362
363+ if pbar is not None :
364+ pbar .close ()
365+
351366 return out
352367
353368
@@ -466,7 +481,13 @@ def body(i, acc_):
466481 _compute_all_quadrants , static_argnames = ("chunk_k" ,)
467482 )
468483
484+ t0 = time .time ()
469485 out = _compute_all_quadrants_jit (gy , gx , chunk_k = chunk_k )
486+ out .block_until_ready () # ensure timing includes actual device execution
487+ t1 = time .time ()
488+
489+ logger .info ("INTERFEROMETER - Finished W-Tilde (JAX) in %.3f seconds" , (t1 - t0 ))
490+
470491 return np .asarray (out )
471492
472493
0 commit comments