11import numpy as np
22from functools import partial
3-
3+ from typing import Optional , List
44
55def w_tilde_data_imaging_from (
66 image_native : np .ndarray ,
@@ -215,8 +215,6 @@ def curvature_matrix_mirrored_from(
215215
216216 return curvature_matrix_mirrored
217217
218- from typing import Optional , List
219-
220218def curvature_matrix_with_added_to_diag_from (
221219 curvature_matrix ,
222220 value : float ,
@@ -381,7 +379,7 @@ def body(block_i, C):
381379
382380
383381
384- def build_curvature_rfft_fn (psf : np .ndarray , y_shape : int , x_shape : int ):
382+ def curvature_matrix_diag_via_w_tilde_from_func (psf : np .ndarray , y_shape : int , x_shape : int ):
385383
386384 import jax
387385 import jax .numpy as jnp
@@ -404,6 +402,123 @@ def build_curvature_rfft_fn(psf: np.ndarray, y_shape: int, x_shape: int):
404402 return curvature_jit
405403
406404
405+ def curvature_matrix_off_diag_via_w_tilde_from (
406+ inv_noise_var , # (Hy, Hx) float64
407+ rows0 , cols0 , vals0 ,
408+ rows1 , cols1 , vals1 ,
409+ y_shape : int ,
410+ x_shape : int ,
411+ S0 : int ,
412+ S1 : int ,
413+ Khat_r , # rfft2(psf padded)
414+ Khat_flip_r , # rfft2(flipped psf padded)
415+ Ky : int ,
416+ Kx : int ,
417+ batch_size : int = 32 ,
418+ ):
419+ """
420+ Off-diagonal curvature block:
421+ F01 = A0^T W A1
422+ Returns: (S0, S1)
423+ """
424+
425+ import jax .numpy as jnp
426+ from jax import lax
427+ from jax .ops import segment_sum
428+
429+ inv_noise_var = jnp .asarray (inv_noise_var , dtype = jnp .float64 )
430+
431+ rows0 = jnp .asarray (rows0 , dtype = jnp .int32 )
432+ cols0 = jnp .asarray (cols0 , dtype = jnp .int32 )
433+ vals0 = jnp .asarray (vals0 , dtype = jnp .float64 )
434+
435+ rows1 = jnp .asarray (rows1 , dtype = jnp .int32 )
436+ cols1 = jnp .asarray (cols1 , dtype = jnp .int32 )
437+ vals1 = jnp .asarray (vals1 , dtype = jnp .float64 )
438+
439+ M = y_shape * x_shape
440+ fft_shape = (y_shape + Ky - 1 , x_shape + Kx - 1 )
441+
442+ def apply_W (Fbatch_flat : jnp .ndarray ) -> jnp .ndarray :
443+ B = Fbatch_flat .shape [1 ]
444+ Fimg = Fbatch_flat .T .reshape ((B , y_shape , x_shape ))
445+ blurred = rfft_convolve2d_same (Fimg , Khat_r , Ky , Kx , fft_shape )
446+ weighted = blurred * inv_noise_var [None , :, :]
447+ back = rfft_convolve2d_same (weighted , Khat_flip_r , Ky , Kx , fft_shape )
448+ return back .reshape ((B , M )).T # (M, B)
449+
450+ # -----------------------------
451+ # FIX: pad output width so dynamic_update_slice never clamps
452+ # -----------------------------
453+ n_blocks = (S1 + batch_size - 1 ) // batch_size
454+ S1_pad = n_blocks * batch_size
455+
456+ F01_0 = jnp .zeros ((S0 , S1_pad ), dtype = jnp .float64 )
457+
458+ col_offsets = jnp .arange (batch_size , dtype = jnp .int32 )
459+
460+ def body (block_i , F01 ):
461+ start = block_i * batch_size
462+
463+ # Select mapper-1 entries in this column block
464+ in_block = (cols1 >= start ) & (cols1 < (start + batch_size ))
465+ bc = jnp .where (in_block , cols1 - start , 0 ).astype (jnp .int32 )
466+ v = jnp .where (in_block , vals1 , 0.0 )
467+
468+ # Assemble RHS block: (M, batch_size)
469+ Fbatch = jnp .zeros ((M , batch_size ), dtype = jnp .float64 )
470+ Fbatch = Fbatch .at [rows1 , bc ].add (v )
471+
472+ # Apply W
473+ Gbatch = apply_W (Fbatch ) # (M, batch_size)
474+
475+ # Project with A0^T -> (S0, batch_size)
476+ contrib = vals0 [:, None ] * Gbatch [rows0 , :]
477+ block = segment_sum (contrib , cols0 , num_segments = S0 ) # (S0, batch_size)
478+
479+ # Mask out columns beyond S1 in the last block
480+ width = jnp .minimum (batch_size , jnp .maximum (0 , S1 - start ))
481+ mask = (col_offsets < width ).astype (jnp .float64 )
482+ block = block * mask [None , :]
483+
484+ # Safe because start+batch_size <= S1_pad always
485+ F01 = lax .dynamic_update_slice (F01 , block , (0 , start ))
486+ return F01
487+
488+ F01_pad = lax .fori_loop (0 , n_blocks , body , F01_0 )
489+
490+ # Slice back to true width
491+ return F01_pad [:, :S1 ]
492+
493+
494+ def build_curvature_matrix_off_diag_via_w_tilde_from_func (psf : np .ndarray , y_shape : int , x_shape : int ):
495+ """
496+ Matches your diagonal curvature_matrix_diag_via_w_tilde_from_func:
497+ - precomputes Khat_r and Khat_flip_r once
498+ - returns a jitted function with the SAME static args pattern
499+ """
500+
501+ import jax
502+ import jax .numpy as jnp
503+
504+ psf = jnp .asarray (psf , dtype = jnp .float64 )
505+ Ky , Kx = psf .shape
506+ fft_shape = (y_shape + Ky - 1 , x_shape + Kx - 1 )
507+
508+ Khat_r = precompute_Khat_rfft (psf , fft_shape )
509+ Khat_flip_r = precompute_Khat_rfft (jnp .flip (psf , axis = (0 , 1 )), fft_shape )
510+
511+ offdiag_jit = jax .jit (
512+ partial (
513+ curvature_matrix_off_diag_via_w_tilde_from ,
514+ Khat_r = Khat_r ,
515+ Khat_flip_r = Khat_flip_r ,
516+ Ky = Ky ,
517+ Kx = Kx ,
518+ ),
519+ static_argnames = ("y_shape" , "x_shape" , "S0" , "S1" , "batch_size" ),
520+ )
521+ return offdiag_jit
407522
408523
409524def mapped_image_rect_from_triplets (
0 commit comments