Skip to content

Commit 77c86d3

Browse files
Jammy2211Jammy2211
authored andcommitted
x2 mapper stuff now works
1 parent 5acea80 commit 77c86d3

File tree

4 files changed

+150
-50
lines changed

4 files changed

+150
-50
lines changed

autoarray/dataset/imaging/w_tilde.py

Lines changed: 7 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -59,23 +59,16 @@ def __init__(
5959

6060
self.inv_noise_var = jnp.asarray(self.inv_noise_var, dtype=jnp.float64)
6161

62-
self.curv_fn = (inversion_imaging_util.build_curvature_rfft_fn(
62+
self.curvature_matrix_diag_func = (inversion_imaging_util.curvature_matrix_diag_via_w_tilde_from_func(
6363
psf=self.psf.native.array,
6464
y_shape=data.shape_native[0],
6565
x_shape=data.shape_native[1],
6666
))
6767

68-
self.batch_size = batch_size
69-
70-
71-
@property
72-
def psf_operator_matrix_dense(self):
68+
self.curvature_matrix_off_diag_func = (inversion_imaging_util.build_curvature_matrix_off_diag_via_w_tilde_from_func(
69+
psf=self.psf.native.array,
70+
y_shape=data.shape_native[0],
71+
x_shape=data.shape_native[1],
72+
))
7373

74-
return inversion_imaging_util.psf_operator_matrix_dense_from(
75-
kernel_native=self.psf.native.array,
76-
native_index_for_slim_index=np.array(
77-
self.mask.derive_indexes.native_for_slim
78-
).astype("int"),
79-
native_shape=self.noise_map.shape_native,
80-
correlate=False,
81-
)
74+
self.batch_size = batch_size

autoarray/inversion/inversion/imaging/inversion_imaging_util.py

Lines changed: 119 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import numpy as np
22
from functools import partial
3-
3+
from typing import Optional, List
44

55
def w_tilde_data_imaging_from(
66
image_native: np.ndarray,
@@ -215,8 +215,6 @@ def curvature_matrix_mirrored_from(
215215

216216
return curvature_matrix_mirrored
217217

218-
from typing import Optional, List
219-
220218
def curvature_matrix_with_added_to_diag_from(
221219
curvature_matrix,
222220
value: float,
@@ -381,7 +379,7 @@ def body(block_i, C):
381379

382380

383381

384-
def build_curvature_rfft_fn(psf: np.ndarray, y_shape: int, x_shape: int):
382+
def curvature_matrix_diag_via_w_tilde_from_func(psf: np.ndarray, y_shape: int, x_shape: int):
385383

386384
import jax
387385
import jax.numpy as jnp
@@ -404,6 +402,123 @@ def build_curvature_rfft_fn(psf: np.ndarray, y_shape: int, x_shape: int):
404402
return curvature_jit
405403

406404

405+
def curvature_matrix_off_diag_via_w_tilde_from(
406+
inv_noise_var, # (Hy, Hx) float64
407+
rows0, cols0, vals0,
408+
rows1, cols1, vals1,
409+
y_shape: int,
410+
x_shape: int,
411+
S0: int,
412+
S1: int,
413+
Khat_r, # rfft2(psf padded)
414+
Khat_flip_r, # rfft2(flipped psf padded)
415+
Ky: int,
416+
Kx: int,
417+
batch_size: int = 32,
418+
):
419+
"""
420+
Off-diagonal curvature block:
421+
F01 = A0^T W A1
422+
Returns: (S0, S1)
423+
"""
424+
425+
import jax.numpy as jnp
426+
from jax import lax
427+
from jax.ops import segment_sum
428+
429+
inv_noise_var = jnp.asarray(inv_noise_var, dtype=jnp.float64)
430+
431+
rows0 = jnp.asarray(rows0, dtype=jnp.int32)
432+
cols0 = jnp.asarray(cols0, dtype=jnp.int32)
433+
vals0 = jnp.asarray(vals0, dtype=jnp.float64)
434+
435+
rows1 = jnp.asarray(rows1, dtype=jnp.int32)
436+
cols1 = jnp.asarray(cols1, dtype=jnp.int32)
437+
vals1 = jnp.asarray(vals1, dtype=jnp.float64)
438+
439+
M = y_shape * x_shape
440+
fft_shape = (y_shape + Ky - 1, x_shape + Kx - 1)
441+
442+
def apply_W(Fbatch_flat: jnp.ndarray) -> jnp.ndarray:
443+
B = Fbatch_flat.shape[1]
444+
Fimg = Fbatch_flat.T.reshape((B, y_shape, x_shape))
445+
blurred = rfft_convolve2d_same(Fimg, Khat_r, Ky, Kx, fft_shape)
446+
weighted = blurred * inv_noise_var[None, :, :]
447+
back = rfft_convolve2d_same(weighted, Khat_flip_r, Ky, Kx, fft_shape)
448+
return back.reshape((B, M)).T # (M, B)
449+
450+
# -----------------------------
451+
# FIX: pad output width so dynamic_update_slice never clamps
452+
# -----------------------------
453+
n_blocks = (S1 + batch_size - 1) // batch_size
454+
S1_pad = n_blocks * batch_size
455+
456+
F01_0 = jnp.zeros((S0, S1_pad), dtype=jnp.float64)
457+
458+
col_offsets = jnp.arange(batch_size, dtype=jnp.int32)
459+
460+
def body(block_i, F01):
461+
start = block_i * batch_size
462+
463+
# Select mapper-1 entries in this column block
464+
in_block = (cols1 >= start) & (cols1 < (start + batch_size))
465+
bc = jnp.where(in_block, cols1 - start, 0).astype(jnp.int32)
466+
v = jnp.where(in_block, vals1, 0.0)
467+
468+
# Assemble RHS block: (M, batch_size)
469+
Fbatch = jnp.zeros((M, batch_size), dtype=jnp.float64)
470+
Fbatch = Fbatch.at[rows1, bc].add(v)
471+
472+
# Apply W
473+
Gbatch = apply_W(Fbatch) # (M, batch_size)
474+
475+
# Project with A0^T -> (S0, batch_size)
476+
contrib = vals0[:, None] * Gbatch[rows0, :]
477+
block = segment_sum(contrib, cols0, num_segments=S0) # (S0, batch_size)
478+
479+
# Mask out columns beyond S1 in the last block
480+
width = jnp.minimum(batch_size, jnp.maximum(0, S1 - start))
481+
mask = (col_offsets < width).astype(jnp.float64)
482+
block = block * mask[None, :]
483+
484+
# Safe because start+batch_size <= S1_pad always
485+
F01 = lax.dynamic_update_slice(F01, block, (0, start))
486+
return F01
487+
488+
F01_pad = lax.fori_loop(0, n_blocks, body, F01_0)
489+
490+
# Slice back to true width
491+
return F01_pad[:, :S1]
492+
493+
494+
def build_curvature_matrix_off_diag_via_w_tilde_from_func(psf: np.ndarray, y_shape: int, x_shape: int):
495+
"""
496+
Matches your diagonal curvature_matrix_diag_via_w_tilde_from_func:
497+
- precomputes Khat_r and Khat_flip_r once
498+
- returns a jitted function with the SAME static args pattern
499+
"""
500+
501+
import jax
502+
import jax.numpy as jnp
503+
504+
psf = jnp.asarray(psf, dtype=jnp.float64)
505+
Ky, Kx = psf.shape
506+
fft_shape = (y_shape + Ky - 1, x_shape + Kx - 1)
507+
508+
Khat_r = precompute_Khat_rfft(psf, fft_shape)
509+
Khat_flip_r = precompute_Khat_rfft(jnp.flip(psf, axis=(0, 1)), fft_shape)
510+
511+
offdiag_jit = jax.jit(
512+
partial(
513+
curvature_matrix_off_diag_via_w_tilde_from,
514+
Khat_r=Khat_r,
515+
Khat_flip_r=Khat_flip_r,
516+
Ky=Ky,
517+
Kx=Kx,
518+
),
519+
static_argnames=("y_shape", "x_shape", "S0", "S1", "batch_size"),
520+
)
521+
return offdiag_jit
407522

408523

409524
def mapped_image_rect_from_triplets(

autoarray/inversion/inversion/imaging/w_tilde.py

Lines changed: 23 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -165,7 +165,7 @@ def _data_vector_multi_mapper(self) -> np.ndarray:
165165
rows=rows,
166166
cols=cols,
167167
vals=vals,
168-
S=mapper.total_params,
168+
S=mapper.params,
169169
)
170170
)
171171

@@ -281,7 +281,7 @@ def _curvature_matrix_mapper_diag(self) -> Optional[np.ndarray]:
281281

282282
rows, cols, vals = mapper_i.pixel_triplets_curvature
283283

284-
diag = self.w_tilde.curv_fn(
284+
diag = self.w_tilde.curvature_matrix_diag_func(
285285
self.w_tilde.inv_noise_var,
286286
rows,
287287
cols,
@@ -317,35 +317,27 @@ def _curvature_matrix_off_diag_from(
317317
This function computes the off-diagonal terms of F using the w_tilde formalism.
318318
"""
319319

320-
curvature_matrix_off_diag_0 = inversion_imaging_numba_util.curvature_matrix_off_diags_via_w_tilde_curvature_preload_imaging_from(
321-
curvature_preload=self.w_tilde.curvature_preload,
322-
curvature_indexes=self.w_tilde.indexes,
323-
curvature_lengths=self.w_tilde.lengths,
324-
data_to_pix_unique_0=mapper_0.unique_mappings.data_to_pix_unique,
325-
data_weights_0=mapper_0.unique_mappings.data_weights,
326-
pix_lengths_0=mapper_0.unique_mappings.pix_lengths,
327-
pix_pixels_0=mapper_0.params,
328-
data_to_pix_unique_1=mapper_1.unique_mappings.data_to_pix_unique,
329-
data_weights_1=mapper_1.unique_mappings.data_weights,
330-
pix_lengths_1=mapper_1.unique_mappings.pix_lengths,
331-
pix_pixels_1=mapper_1.params,
332-
)
333-
334-
curvature_matrix_off_diag_1 = inversion_imaging_numba_util.curvature_matrix_off_diags_via_w_tilde_curvature_preload_imaging_from(
335-
curvature_preload=self.w_tilde.curvature_preload,
336-
curvature_indexes=self.w_tilde.indexes,
337-
curvature_lengths=self.w_tilde.lengths,
338-
data_to_pix_unique_0=mapper_1.unique_mappings.data_to_pix_unique,
339-
data_weights_0=mapper_1.unique_mappings.data_weights,
340-
pix_lengths_0=mapper_1.unique_mappings.pix_lengths,
341-
pix_pixels_0=mapper_1.params,
342-
data_to_pix_unique_1=mapper_0.unique_mappings.data_to_pix_unique,
343-
data_weights_1=mapper_0.unique_mappings.data_weights,
344-
pix_lengths_1=mapper_0.unique_mappings.pix_lengths,
345-
pix_pixels_1=mapper_0.params,
346-
)
347-
348-
return curvature_matrix_off_diag_0 + curvature_matrix_off_diag_1.T
320+
rows0, cols0, vals0 = mapper_0.pixel_triplets_curvature
321+
rows1, cols1, vals1 = mapper_1.pixel_triplets_curvature
322+
323+
S0 = mapper_0.params
324+
S1 = mapper_1.params
325+
326+
(y_shape, x_shape) = self.mask.shape_native
327+
328+
return self.w_tilde.curvature_matrix_off_diag_func(
329+
inv_noise_var=self.w_tilde.inv_noise_var,
330+
rows0=rows0,
331+
cols0=cols0,
332+
vals0=vals0,
333+
rows1=rows1,
334+
cols1=cols1,
335+
vals1=vals1,
336+
y_shape=y_shape,
337+
x_shape=x_shape,
338+
S0=S0,
339+
S1=S1,
340+
)
349341

350342
@property
351343
def _curvature_matrix_x1_mapper(self) -> np.ndarray:

test_autoarray/inversion/inversion/imaging/test_inversion_imaging_util.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -290,7 +290,7 @@ def test__curvature_matrix_via_w_tilde_two_methods_agree():
290290
return_rows_slim=False,
291291
)
292292

293-
curvature_matrix_via_w_tilde = w_tilde.curv_fn(
293+
curvature_matrix_via_w_tilde = w_tilde.curvature_matrix_diag_func(
294294
w_tilde.inv_noise_var,
295295
rows,
296296
cols,

0 commit comments

Comments
 (0)