22from functools import partial
33from typing import Optional , List
44
5+
56def w_tilde_data_imaging_from (
67 image_native : np .ndarray ,
78 noise_map_native : np .ndarray ,
@@ -74,11 +75,11 @@ def w_tilde_data_imaging_from(
7475
7576
7677def data_vector_via_w_tilde_from (
77- w_tilde_data : np .ndarray , # (M_pix,) float64
78- rows : np .ndarray , # (nnz,) int32 each triplet's data pixel (slim index)
79- cols : np .ndarray , # (nnz,) int32 source pixel index
80- vals : np .ndarray , # (nnz,) float64 mapping weights incl sub_fraction
81- S : int , # number of source pixels
78+ w_tilde_data : np .ndarray , # (M_pix,) float64
79+ rows : np .ndarray , # (nnz,) int32 each triplet's data pixel (slim index)
80+ cols : np .ndarray , # (nnz,) int32 source pixel index
81+ vals : np .ndarray , # (nnz,) float64 mapping weights incl sub_fraction
82+ S : int , # number of source pixels
8283) -> np .ndarray :
8384 """
8485 Replacement for numba data_vector_via_w_tilde_data_imaging_from using triplets.
@@ -91,8 +92,8 @@ def data_vector_via_w_tilde_from(
9192 """
9293 from jax .ops import segment_sum
9394
94- w = w_tilde_data [rows ] # (nnz,)
95- contrib = vals * w # (nnz,)
95+ w = w_tilde_data [rows ] # (nnz,)
96+ contrib = vals * w # (nnz,)
9697 return segment_sum (contrib , cols , num_segments = S ) # (S,)
9798
9899
@@ -215,6 +216,7 @@ def curvature_matrix_mirrored_from(
215216
216217 return curvature_matrix_mirrored
217218
219+
218220def curvature_matrix_with_added_to_diag_from (
219221 curvature_matrix ,
220222 value : float ,
@@ -271,7 +273,7 @@ def curvature_matrix_with_added_to_diag_from(
271273def build_inv_noise_var (noise ):
272274 inv = np .zeros_like (noise , dtype = np .float64 )
273275 good = np .isfinite (noise ) & (noise > 0 )
274- inv [good ] = 1.0 / noise [good ]** 2
276+ inv [good ] = 1.0 / noise [good ] ** 2
275277 return inv
276278
277279
@@ -290,7 +292,9 @@ def precompute_Khat_rfft(kernel_2d: np.ndarray, fft_shape):
290292 return jnp .fft .rfft2 (kernel_pad , s = (Fy , Fx ))
291293
292294
293- def rfft_convolve2d_same (images : np .ndarray , Khat_r : np .ndarray , Ky : int , Kx : int , fft_shape ):
295+ def rfft_convolve2d_same (
296+ images : np .ndarray , Khat_r : np .ndarray , Ky : int , Kx : int , fft_shape
297+ ):
294298 """
295299 Batched real FFT convolution, returning 'same' output.
296300
@@ -305,21 +309,25 @@ def rfft_convolve2d_same(images: np.ndarray, Khat_r: np.ndarray, Ky: int, Kx: in
305309 Fy , Fx = fft_shape
306310
307311 images_pad = jnp .pad (images , ((0 , 0 ), (0 , Fy - Hy ), (0 , Fx - Hx )))
308- Fhat = jnp .fft .rfft2 (images_pad , s = (Fy , Fx )) # (B, Fy, Fx//2+1)
312+ Fhat = jnp .fft .rfft2 (images_pad , s = (Fy , Fx )) # (B, Fy, Fx//2+1)
309313 out_pad = jnp .fft .irfft2 (Fhat * Khat_r [None , :, :], s = (Fy , Fx )) # (B, Fy, Fx), real
310314
311315 cy , cx = Ky // 2 , Kx // 2
312- return out_pad [:, cy :cy + Hy , cx :cx + Hx ]
313-
316+ return out_pad [:, cy : cy + Hy , cx : cx + Hx ]
314317
315318
316319def curvature_matrix_diag_via_w_tilde_from (
317320 inv_noise_var ,
318- rows , cols , vals ,
319- y_shape : int , x_shape : int ,
321+ rows ,
322+ cols ,
323+ vals ,
324+ y_shape : int ,
325+ x_shape : int ,
320326 S : int ,
321- Khat_r , Khat_flip_r ,
322- Ky : int , Kx : int ,
327+ Khat_r ,
328+ Khat_flip_r ,
329+ Ky : int ,
330+ Kx : int ,
323331 batch_size : int = 32 ,
324332):
325333 from jax import lax
@@ -353,14 +361,14 @@ def body(block_i, C):
353361
354362 in_block = (cols >= start ) & (cols < (start + batch_size ))
355363 bc = jnp .where (in_block , cols - start , 0 ).astype (jnp .int32 )
356- v = jnp .where (in_block , vals , 0.0 )
364+ v = jnp .where (in_block , vals , 0.0 )
357365
358366 F = jnp .zeros ((M , batch_size ), dtype = jnp .float64 )
359367 F = F .at [rows , bc ].add (v )
360368
361369 G = apply_W (F ) # (M, batch_size)
362370
363- contrib = vals [:, None ] * G [rows , :] # (nnz, batch_size)
371+ contrib = vals [:, None ] * G [rows , :] # (nnz, batch_size)
364372 Cblock = segment_sum (contrib , cols , num_segments = S ) # (S, batch_size)
365373
366374 # Mask out unused columns in last block (optional but nice)
@@ -373,13 +381,14 @@ def body(block_i, C):
373381 return C
374382
375383 C_pad = lax .fori_loop (0 , n_blocks , body , C0 )
376- C = C_pad [:, :S ] # <-- slice back to true width
384+ C = C_pad [:, :S ] # <-- slice back to true width
377385
378386 return 0.5 * (C + C .T )
379387
380388
381-
382- def curvature_matrix_diag_via_w_tilde_from_func (psf : np .ndarray , y_shape : int , x_shape : int ):
389+ def curvature_matrix_diag_via_w_tilde_from_func (
390+ psf : np .ndarray , y_shape : int , x_shape : int
391+ ):
383392
384393 import jax
385394 import jax .numpy as jnp
@@ -396,22 +405,32 @@ def curvature_matrix_diag_via_w_tilde_from_func(psf: np.ndarray, y_shape: int, x
396405
397406 # Jit wrapper with static shapes
398407 curvature_jit = jax .jit (
399- partial (curvature_matrix_diag_via_w_tilde_from , Khat_r = Khat_r , Khat_flip_r = Khat_flip_r , Ky = Ky , Kx = Kx ),
408+ partial (
409+ curvature_matrix_diag_via_w_tilde_from ,
410+ Khat_r = Khat_r ,
411+ Khat_flip_r = Khat_flip_r ,
412+ Ky = Ky ,
413+ Kx = Kx ,
414+ ),
400415 static_argnames = ("y_shape" , "x_shape" , "S" , "batch_size" ),
401416 )
402417 return curvature_jit
403418
404419
405420def curvature_matrix_off_diag_via_w_tilde_from (
406- inv_noise_var , # (Hy, Hx) float64
407- rows0 , cols0 , vals0 ,
408- rows1 , cols1 , vals1 ,
421+ inv_noise_var , # (Hy, Hx) float64
422+ rows0 ,
423+ cols0 ,
424+ vals0 ,
425+ rows1 ,
426+ cols1 ,
427+ vals1 ,
409428 y_shape : int ,
410429 x_shape : int ,
411430 S0 : int ,
412431 S1 : int ,
413- Khat_r , # rfft2(psf padded)
414- Khat_flip_r , # rfft2(flipped psf padded)
432+ Khat_r , # rfft2(psf padded)
433+ Khat_flip_r , # rfft2(flipped psf padded)
415434 Ky : int ,
416435 Kx : int ,
417436 batch_size : int = 32 ,
@@ -463,7 +482,7 @@ def body(block_i, F01):
463482 # Select mapper-1 entries in this column block
464483 in_block = (cols1 >= start ) & (cols1 < (start + batch_size ))
465484 bc = jnp .where (in_block , cols1 - start , 0 ).astype (jnp .int32 )
466- v = jnp .where (in_block , vals1 , 0.0 )
485+ v = jnp .where (in_block , vals1 , 0.0 )
467486
468487 # Assemble RHS block: (M, batch_size)
469488 Fbatch = jnp .zeros ((M , batch_size ), dtype = jnp .float64 )
@@ -491,7 +510,9 @@ def body(block_i, F01):
491510 return F01_pad [:, :S1 ]
492511
493512
494- def build_curvature_matrix_off_diag_via_w_tilde_from_func (psf : np .ndarray , y_shape : int , x_shape : int ):
513+ def build_curvature_matrix_off_diag_via_w_tilde_from_func (
514+ psf : np .ndarray , y_shape : int , x_shape : int
515+ ):
495516 """
496517 Matches your diagonal curvature_matrix_diag_via_w_tilde_from_func:
497518 - precomputes Khat_r and Khat_flip_r once
@@ -522,13 +543,15 @@ def build_curvature_matrix_off_diag_via_w_tilde_from_func(psf: np.ndarray, y_sha
522543
523544
524545def curvature_matrix_off_diag_with_light_profiles_via_w_tilde_from (
525- curvature_weights , # (M_pix, n_funcs) = (H B) / noise^2 on slim grid
546+ curvature_weights , # (M_pix, n_funcs) = (H B) / noise^2 on slim grid
526547 fft_index_for_masked_pixel , # (M_pix,) slim -> rect(flat) indices
527- rows , cols , vals , # triplets for sparse mapper A
548+ rows ,
549+ cols ,
550+ vals , # triplets for sparse mapper A
528551 y_shape : int ,
529552 x_shape : int ,
530553 S : int ,
531- Khat_flip_r , # precomputed rfft2(flipped PSF padded)
554+ Khat_flip_r , # precomputed rfft2(flipped PSF padded)
532555 Ky : int ,
533556 Kx : int ,
534557):
@@ -542,7 +565,9 @@ def curvature_matrix_off_diag_with_light_profiles_via_w_tilde_from(
542565 from jax .ops import segment_sum
543566
544567 curvature_weights = jnp .asarray (curvature_weights , dtype = jnp .float64 )
545- fft_index_for_masked_pixel = jnp .asarray (fft_index_for_masked_pixel , dtype = jnp .int32 )
568+ fft_index_for_masked_pixel = jnp .asarray (
569+ fft_index_for_masked_pixel , dtype = jnp .int32
570+ )
546571
547572 rows = jnp .asarray (rows , dtype = jnp .int32 )
548573 cols = jnp .asarray (cols , dtype = jnp .int32 )
@@ -561,16 +586,18 @@ def curvature_matrix_off_diag_with_light_profiles_via_w_tilde_from(
561586 back_native = rfft_convolve2d_same (images , Khat_flip_r , Ky , Kx , fft_shape )
562587
563588 # 3) gather at mapper rows
564- back_flat = back_native .reshape ((n_funcs , M_rect )).T # (M_rect, n_funcs)
565- back_at_rows = back_flat [rows , :] # (nnz, n_funcs)
589+ back_flat = back_native .reshape ((n_funcs , M_rect )).T # (M_rect, n_funcs)
590+ back_at_rows = back_flat [rows , :] # (nnz, n_funcs)
566591
567592 # 4) accumulate into sparse pixels
568593 contrib = vals [:, None ] * back_at_rows
569- off_diag = segment_sum (contrib , cols , num_segments = S ) # (S, n_funcs)
594+ off_diag = segment_sum (contrib , cols , num_segments = S ) # (S, n_funcs)
570595 return off_diag
571596
572597
573- def build_curvature_matrix_off_diag_with_light_profiles_via_w_tilde_from_func (psf : np .ndarray , y_shape : int , x_shape : int ):
598+ def build_curvature_matrix_off_diag_with_light_profiles_via_w_tilde_from_func (
599+ psf : np .ndarray , y_shape : int , x_shape : int
600+ ):
574601
575602 import jax
576603 import jax .numpy as jnp
@@ -595,12 +622,12 @@ def build_curvature_matrix_off_diag_with_light_profiles_via_w_tilde_from_func(ps
595622
596623
597624def mapped_image_rect_from_triplets (
598- reconstruction , # (S,)
625+ reconstruction , # (S,)
599626 rows ,
600627 cols ,
601- vals , # (nnz,)
628+ vals , # (nnz,)
602629 fft_index_for_masked_pixel ,
603- data_shape : int , # y_shape * x_shape
630+ data_shape : int , # y_shape * x_shape
604631):
605632 import jax .numpy as jnp
606633 from jax .ops import segment_sum
@@ -610,8 +637,10 @@ def mapped_image_rect_from_triplets(
610637 cols = jnp .asarray (cols , dtype = jnp .int32 )
611638 vals = jnp .asarray (vals , dtype = jnp .float64 )
612639
613- contrib = vals * reconstruction [cols ] # (nnz,)
614- image_rect = segment_sum (contrib , rows , num_segments = data_shape [0 ] * data_shape [1 ]) # (M_rect,)
640+ contrib = vals * reconstruction [cols ] # (nnz,)
641+ image_rect = segment_sum (
642+ contrib , rows , num_segments = data_shape [0 ] * data_shape [1 ]
643+ ) # (M_rect,)
615644
616- image_slim = image_rect [fft_index_for_masked_pixel ] # (M_pix,)
645+ image_slim = image_rect [fft_index_for_masked_pixel ] # (M_pix,)
617646 return image_slim
0 commit comments