Skip to content

Commit a84e98c

Browse files
Jammy2211Jammy2211
authored andcommitted
black
1 parent 09d81cf commit a84e98c

File tree

10 files changed

+153
-128
lines changed

10 files changed

+153
-128
lines changed

autoarray/dataset/abstract/w_tilde.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44

55

66
class AbstractWTilde:
7-
def __init__(self, curvature_preload : np.ndarray, fft_mask: np.ndarray):
7+
def __init__(self, curvature_preload: np.ndarray, fft_mask: np.ndarray):
88
"""
99
Packages together all derived data quantities necessary to fit `data (e.g. `Imaging`, Interferometer`) using
1010
an ` Inversion` via the w_tilde formalism.
@@ -52,4 +52,4 @@ def fft_index_for_masked_pixel(self) -> np.ndarray:
5252
- This method is intentionally backend-agnostic and can be used by both
5353
imaging and interferometer curvature pipelines.
5454
"""
55-
self.fft_mask.fft_index_for_masked_pixel
55+
self.fft_mask.fft_index_for_masked_pixel

autoarray/dataset/imaging/w_tilde.py

Lines changed: 14 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ def __init__(
1616
noise_map: np.ndarray,
1717
psf: np.ndarray,
1818
fft_mask: np.ndarray,
19-
batch_size: int = 128
19+
batch_size: int = 128,
2020
):
2121
"""
2222
Packages together all derived data quantities necessary to fit `Imaging` data using an ` Inversion` via the
@@ -38,10 +38,7 @@ def __init__(
3838
The lengths of how many indexes each curvature preload contains, again used to compute the curvature
3939
matrix efficienctly.
4040
"""
41-
super().__init__(
42-
curvature_preload=None,
43-
fft_mask=fft_mask
44-
)
41+
super().__init__(curvature_preload=None, fft_mask=fft_mask)
4542

4643
self.data = data
4744
self.noise_map = noise_map
@@ -50,7 +47,7 @@ def __init__(
5047
self.data_native = data.native
5148
self.noise_map_native = noise_map.native
5249

53-
self.inv_noise_var = inversion_imaging_util.build_inv_noise_var(
50+
self.inv_noise_var = inversion_imaging_util.build_inv_noise_var(
5451
noise=self.noise_map.native
5552
)
5653
self.inv_noise_var[self.data.mask] = 0.0
@@ -59,23 +56,24 @@ def __init__(
5956

6057
self.inv_noise_var = jnp.asarray(self.inv_noise_var, dtype=jnp.float64)
6158

62-
self.curvature_matrix_diag_func = (inversion_imaging_util.curvature_matrix_diag_via_w_tilde_from_func(
63-
psf=self.psf.native.array,
64-
y_shape=data.shape_native[0],
65-
x_shape=data.shape_native[1],
66-
))
59+
self.curvature_matrix_diag_func = (
60+
inversion_imaging_util.curvature_matrix_diag_via_w_tilde_from_func(
61+
psf=self.psf.native.array,
62+
y_shape=data.shape_native[0],
63+
x_shape=data.shape_native[1],
64+
)
65+
)
6766

68-
self.curvature_matrix_off_diag_func = (inversion_imaging_util.build_curvature_matrix_off_diag_via_w_tilde_from_func(
67+
self.curvature_matrix_off_diag_func = inversion_imaging_util.build_curvature_matrix_off_diag_via_w_tilde_from_func(
6968
psf=self.psf.native.array,
7069
y_shape=data.shape_native[0],
7170
x_shape=data.shape_native[1],
72-
))
71+
)
7372

74-
self.curvature_matrix_off_diag_light_profiles_func = (inversion_imaging_util.build_curvature_matrix_off_diag_with_light_profiles_via_w_tilde_from_func(
73+
self.curvature_matrix_off_diag_light_profiles_func = inversion_imaging_util.build_curvature_matrix_off_diag_with_light_profiles_via_w_tilde_from_func(
7574
psf=self.psf.native.array,
7675
y_shape=data.shape_native[0],
7776
x_shape=data.shape_native[1],
78-
))
79-
77+
)
8078

8179
self.batch_size = batch_size

autoarray/dataset/interferometer/w_tilde.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -236,9 +236,7 @@ def __init__(
236236
The size of batches used to compute the w-tilde curvature matrix via FFT-based convolution,
237237
which can be reduced to produce lower memory usage at the cost of speed.
238238
"""
239-
super().__init__(
240-
curvature_preload=curvature_preload, fft_mask=fft_mask
241-
)
239+
super().__init__(curvature_preload=curvature_preload, fft_mask=fft_mask)
242240

243241
self.dirty_image = dirty_image
244242

autoarray/inversion/inversion/imaging/inversion_imaging_util.py

Lines changed: 72 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
from functools import partial
33
from typing import Optional, List
44

5+
56
def w_tilde_data_imaging_from(
67
image_native: np.ndarray,
78
noise_map_native: np.ndarray,
@@ -74,11 +75,11 @@ def w_tilde_data_imaging_from(
7475

7576

7677
def data_vector_via_w_tilde_from(
77-
w_tilde_data: np.ndarray, # (M_pix,) float64
78-
rows: np.ndarray, # (nnz,) int32 each triplet's data pixel (slim index)
79-
cols: np.ndarray, # (nnz,) int32 source pixel index
80-
vals: np.ndarray, # (nnz,) float64 mapping weights incl sub_fraction
81-
S: int, # number of source pixels
78+
w_tilde_data: np.ndarray, # (M_pix,) float64
79+
rows: np.ndarray, # (nnz,) int32 each triplet's data pixel (slim index)
80+
cols: np.ndarray, # (nnz,) int32 source pixel index
81+
vals: np.ndarray, # (nnz,) float64 mapping weights incl sub_fraction
82+
S: int, # number of source pixels
8283
) -> np.ndarray:
8384
"""
8485
Replacement for numba data_vector_via_w_tilde_data_imaging_from using triplets.
@@ -91,8 +92,8 @@ def data_vector_via_w_tilde_from(
9192
"""
9293
from jax.ops import segment_sum
9394

94-
w = w_tilde_data[rows] # (nnz,)
95-
contrib = vals * w # (nnz,)
95+
w = w_tilde_data[rows] # (nnz,)
96+
contrib = vals * w # (nnz,)
9697
return segment_sum(contrib, cols, num_segments=S) # (S,)
9798

9899

@@ -215,6 +216,7 @@ def curvature_matrix_mirrored_from(
215216

216217
return curvature_matrix_mirrored
217218

219+
218220
def curvature_matrix_with_added_to_diag_from(
219221
curvature_matrix,
220222
value: float,
@@ -271,7 +273,7 @@ def curvature_matrix_with_added_to_diag_from(
271273
def build_inv_noise_var(noise):
272274
inv = np.zeros_like(noise, dtype=np.float64)
273275
good = np.isfinite(noise) & (noise > 0)
274-
inv[good] = 1.0 / noise[good]**2
276+
inv[good] = 1.0 / noise[good] ** 2
275277
return inv
276278

277279

@@ -290,7 +292,9 @@ def precompute_Khat_rfft(kernel_2d: np.ndarray, fft_shape):
290292
return jnp.fft.rfft2(kernel_pad, s=(Fy, Fx))
291293

292294

293-
def rfft_convolve2d_same(images: np.ndarray, Khat_r: np.ndarray, Ky: int, Kx: int, fft_shape):
295+
def rfft_convolve2d_same(
296+
images: np.ndarray, Khat_r: np.ndarray, Ky: int, Kx: int, fft_shape
297+
):
294298
"""
295299
Batched real FFT convolution, returning 'same' output.
296300
@@ -305,21 +309,25 @@ def rfft_convolve2d_same(images: np.ndarray, Khat_r: np.ndarray, Ky: int, Kx: in
305309
Fy, Fx = fft_shape
306310

307311
images_pad = jnp.pad(images, ((0, 0), (0, Fy - Hy), (0, Fx - Hx)))
308-
Fhat = jnp.fft.rfft2(images_pad, s=(Fy, Fx)) # (B, Fy, Fx//2+1)
312+
Fhat = jnp.fft.rfft2(images_pad, s=(Fy, Fx)) # (B, Fy, Fx//2+1)
309313
out_pad = jnp.fft.irfft2(Fhat * Khat_r[None, :, :], s=(Fy, Fx)) # (B, Fy, Fx), real
310314

311315
cy, cx = Ky // 2, Kx // 2
312-
return out_pad[:, cy:cy + Hy, cx:cx + Hx]
313-
316+
return out_pad[:, cy : cy + Hy, cx : cx + Hx]
314317

315318

316319
def curvature_matrix_diag_via_w_tilde_from(
317320
inv_noise_var,
318-
rows, cols, vals,
319-
y_shape: int, x_shape: int,
321+
rows,
322+
cols,
323+
vals,
324+
y_shape: int,
325+
x_shape: int,
320326
S: int,
321-
Khat_r, Khat_flip_r,
322-
Ky: int, Kx: int,
327+
Khat_r,
328+
Khat_flip_r,
329+
Ky: int,
330+
Kx: int,
323331
batch_size: int = 32,
324332
):
325333
from jax import lax
@@ -353,14 +361,14 @@ def body(block_i, C):
353361

354362
in_block = (cols >= start) & (cols < (start + batch_size))
355363
bc = jnp.where(in_block, cols - start, 0).astype(jnp.int32)
356-
v = jnp.where(in_block, vals, 0.0)
364+
v = jnp.where(in_block, vals, 0.0)
357365

358366
F = jnp.zeros((M, batch_size), dtype=jnp.float64)
359367
F = F.at[rows, bc].add(v)
360368

361369
G = apply_W(F) # (M, batch_size)
362370

363-
contrib = vals[:, None] * G[rows, :] # (nnz, batch_size)
371+
contrib = vals[:, None] * G[rows, :] # (nnz, batch_size)
364372
Cblock = segment_sum(contrib, cols, num_segments=S) # (S, batch_size)
365373

366374
# Mask out unused columns in last block (optional but nice)
@@ -373,13 +381,14 @@ def body(block_i, C):
373381
return C
374382

375383
C_pad = lax.fori_loop(0, n_blocks, body, C0)
376-
C = C_pad[:, :S] # <-- slice back to true width
384+
C = C_pad[:, :S] # <-- slice back to true width
377385

378386
return 0.5 * (C + C.T)
379387

380388

381-
382-
def curvature_matrix_diag_via_w_tilde_from_func(psf: np.ndarray, y_shape: int, x_shape: int):
389+
def curvature_matrix_diag_via_w_tilde_from_func(
390+
psf: np.ndarray, y_shape: int, x_shape: int
391+
):
383392

384393
import jax
385394
import jax.numpy as jnp
@@ -396,22 +405,32 @@ def curvature_matrix_diag_via_w_tilde_from_func(psf: np.ndarray, y_shape: int, x
396405

397406
# Jit wrapper with static shapes
398407
curvature_jit = jax.jit(
399-
partial(curvature_matrix_diag_via_w_tilde_from, Khat_r=Khat_r, Khat_flip_r=Khat_flip_r, Ky=Ky, Kx=Kx),
408+
partial(
409+
curvature_matrix_diag_via_w_tilde_from,
410+
Khat_r=Khat_r,
411+
Khat_flip_r=Khat_flip_r,
412+
Ky=Ky,
413+
Kx=Kx,
414+
),
400415
static_argnames=("y_shape", "x_shape", "S", "batch_size"),
401416
)
402417
return curvature_jit
403418

404419

405420
def curvature_matrix_off_diag_via_w_tilde_from(
406-
inv_noise_var, # (Hy, Hx) float64
407-
rows0, cols0, vals0,
408-
rows1, cols1, vals1,
421+
inv_noise_var, # (Hy, Hx) float64
422+
rows0,
423+
cols0,
424+
vals0,
425+
rows1,
426+
cols1,
427+
vals1,
409428
y_shape: int,
410429
x_shape: int,
411430
S0: int,
412431
S1: int,
413-
Khat_r, # rfft2(psf padded)
414-
Khat_flip_r, # rfft2(flipped psf padded)
432+
Khat_r, # rfft2(psf padded)
433+
Khat_flip_r, # rfft2(flipped psf padded)
415434
Ky: int,
416435
Kx: int,
417436
batch_size: int = 32,
@@ -463,7 +482,7 @@ def body(block_i, F01):
463482
# Select mapper-1 entries in this column block
464483
in_block = (cols1 >= start) & (cols1 < (start + batch_size))
465484
bc = jnp.where(in_block, cols1 - start, 0).astype(jnp.int32)
466-
v = jnp.where(in_block, vals1, 0.0)
485+
v = jnp.where(in_block, vals1, 0.0)
467486

468487
# Assemble RHS block: (M, batch_size)
469488
Fbatch = jnp.zeros((M, batch_size), dtype=jnp.float64)
@@ -491,7 +510,9 @@ def body(block_i, F01):
491510
return F01_pad[:, :S1]
492511

493512

494-
def build_curvature_matrix_off_diag_via_w_tilde_from_func(psf: np.ndarray, y_shape: int, x_shape: int):
513+
def build_curvature_matrix_off_diag_via_w_tilde_from_func(
514+
psf: np.ndarray, y_shape: int, x_shape: int
515+
):
495516
"""
496517
Matches your diagonal curvature_matrix_diag_via_w_tilde_from_func:
497518
- precomputes Khat_r and Khat_flip_r once
@@ -522,13 +543,15 @@ def build_curvature_matrix_off_diag_via_w_tilde_from_func(psf: np.ndarray, y_sha
522543

523544

524545
def curvature_matrix_off_diag_with_light_profiles_via_w_tilde_from(
525-
curvature_weights, # (M_pix, n_funcs) = (H B) / noise^2 on slim grid
546+
curvature_weights, # (M_pix, n_funcs) = (H B) / noise^2 on slim grid
526547
fft_index_for_masked_pixel, # (M_pix,) slim -> rect(flat) indices
527-
rows, cols, vals, # triplets for sparse mapper A
548+
rows,
549+
cols,
550+
vals, # triplets for sparse mapper A
528551
y_shape: int,
529552
x_shape: int,
530553
S: int,
531-
Khat_flip_r, # precomputed rfft2(flipped PSF padded)
554+
Khat_flip_r, # precomputed rfft2(flipped PSF padded)
532555
Ky: int,
533556
Kx: int,
534557
):
@@ -542,7 +565,9 @@ def curvature_matrix_off_diag_with_light_profiles_via_w_tilde_from(
542565
from jax.ops import segment_sum
543566

544567
curvature_weights = jnp.asarray(curvature_weights, dtype=jnp.float64)
545-
fft_index_for_masked_pixel = jnp.asarray(fft_index_for_masked_pixel, dtype=jnp.int32)
568+
fft_index_for_masked_pixel = jnp.asarray(
569+
fft_index_for_masked_pixel, dtype=jnp.int32
570+
)
546571

547572
rows = jnp.asarray(rows, dtype=jnp.int32)
548573
cols = jnp.asarray(cols, dtype=jnp.int32)
@@ -561,16 +586,18 @@ def curvature_matrix_off_diag_with_light_profiles_via_w_tilde_from(
561586
back_native = rfft_convolve2d_same(images, Khat_flip_r, Ky, Kx, fft_shape)
562587

563588
# 3) gather at mapper rows
564-
back_flat = back_native.reshape((n_funcs, M_rect)).T # (M_rect, n_funcs)
565-
back_at_rows = back_flat[rows, :] # (nnz, n_funcs)
589+
back_flat = back_native.reshape((n_funcs, M_rect)).T # (M_rect, n_funcs)
590+
back_at_rows = back_flat[rows, :] # (nnz, n_funcs)
566591

567592
# 4) accumulate into sparse pixels
568593
contrib = vals[:, None] * back_at_rows
569-
off_diag = segment_sum(contrib, cols, num_segments=S) # (S, n_funcs)
594+
off_diag = segment_sum(contrib, cols, num_segments=S) # (S, n_funcs)
570595
return off_diag
571596

572597

573-
def build_curvature_matrix_off_diag_with_light_profiles_via_w_tilde_from_func(psf: np.ndarray, y_shape: int, x_shape: int):
598+
def build_curvature_matrix_off_diag_with_light_profiles_via_w_tilde_from_func(
599+
psf: np.ndarray, y_shape: int, x_shape: int
600+
):
574601

575602
import jax
576603
import jax.numpy as jnp
@@ -595,12 +622,12 @@ def build_curvature_matrix_off_diag_with_light_profiles_via_w_tilde_from_func(ps
595622

596623

597624
def mapped_image_rect_from_triplets(
598-
reconstruction, # (S,)
625+
reconstruction, # (S,)
599626
rows,
600627
cols,
601-
vals, # (nnz,)
628+
vals, # (nnz,)
602629
fft_index_for_masked_pixel,
603-
data_shape: int, # y_shape * x_shape
630+
data_shape: int, # y_shape * x_shape
604631
):
605632
import jax.numpy as jnp
606633
from jax.ops import segment_sum
@@ -610,8 +637,10 @@ def mapped_image_rect_from_triplets(
610637
cols = jnp.asarray(cols, dtype=jnp.int32)
611638
vals = jnp.asarray(vals, dtype=jnp.float64)
612639

613-
contrib = vals * reconstruction[cols] # (nnz,)
614-
image_rect = segment_sum(contrib, rows, num_segments=data_shape[0] * data_shape[1]) # (M_rect,)
640+
contrib = vals * reconstruction[cols] # (nnz,)
641+
image_rect = segment_sum(
642+
contrib, rows, num_segments=data_shape[0] * data_shape[1]
643+
) # (M_rect,)
615644

616-
image_slim = image_rect[fft_index_for_masked_pixel] # (M_pix,)
645+
image_slim = image_rect[fft_index_for_masked_pixel] # (M_pix,)
617646
return image_slim

0 commit comments

Comments
 (0)