Skip to content

Commit 9e7b1de

Browse files
Jammy2211Jammy2211
authored andcommitted
fix now pointless unit ests
1 parent 77c86d3 commit 9e7b1de

File tree

6 files changed

+131
-29
lines changed

6 files changed

+131
-29
lines changed

autoarray/dataset/imaging/w_tilde.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -71,4 +71,11 @@ def __init__(
7171
x_shape=data.shape_native[1],
7272
))
7373

74+
self.curvature_matrix_off_diag_light_profiles_func = (inversion_imaging_util.build_curvature_matrix_off_diag_with_light_profiles_via_w_tilde_from_func(
75+
psf=self.psf.native.array,
76+
y_shape=data.shape_native[0],
77+
x_shape=data.shape_native[1],
78+
))
79+
80+
7481
self.batch_size = batch_size

autoarray/inversion/inversion/imaging/inversion_imaging_util.py

Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -521,6 +521,79 @@ def build_curvature_matrix_off_diag_via_w_tilde_from_func(psf: np.ndarray, y_sha
521521
return offdiag_jit
522522

523523

524+
def curvature_matrix_off_diag_with_light_profiles_via_w_tilde_from(
525+
curvature_weights, # (M_pix, n_funcs) = (H B) / noise^2 on slim grid
526+
fft_index_for_masked_pixel, # (M_pix,) slim -> rect(flat) indices
527+
rows, cols, vals, # triplets for sparse mapper A
528+
y_shape: int,
529+
x_shape: int,
530+
S: int,
531+
Khat_flip_r, # precomputed rfft2(flipped PSF padded)
532+
Ky: int,
533+
Kx: int,
534+
):
535+
"""
536+
Computes: off_diag = A^T [ H^T(curvature_weights_native) ]
537+
where curvature_weights = (H B) / noise^2 already.
538+
"""
539+
540+
import jax
541+
import jax.numpy as jnp
542+
from jax.ops import segment_sum
543+
544+
curvature_weights = jnp.asarray(curvature_weights, dtype=jnp.float64)
545+
fft_index_for_masked_pixel = jnp.asarray(fft_index_for_masked_pixel, dtype=jnp.int32)
546+
547+
rows = jnp.asarray(rows, dtype=jnp.int32)
548+
cols = jnp.asarray(cols, dtype=jnp.int32)
549+
vals = jnp.asarray(vals, dtype=jnp.float64)
550+
551+
M_pix, n_funcs = curvature_weights.shape
552+
M_rect = y_shape * x_shape
553+
fft_shape = (y_shape + Ky - 1, x_shape + Kx - 1)
554+
555+
# 1) scatter slim weights onto rectangular grid (flat)
556+
grid_flat = jnp.zeros((M_rect, n_funcs), dtype=jnp.float64)
557+
grid_flat = grid_flat.at[fft_index_for_masked_pixel, :].set(curvature_weights)
558+
559+
# 2) apply H^T = convolution with flipped PSF (one convolution)
560+
images = grid_flat.T.reshape((n_funcs, y_shape, x_shape)) # (B=n_funcs, Hy, Hx)
561+
back_native = rfft_convolve2d_same(images, Khat_flip_r, Ky, Kx, fft_shape)
562+
563+
# 3) gather at mapper rows
564+
back_flat = back_native.reshape((n_funcs, M_rect)).T # (M_rect, n_funcs)
565+
back_at_rows = back_flat[rows, :] # (nnz, n_funcs)
566+
567+
# 4) accumulate into sparse pixels
568+
contrib = vals[:, None] * back_at_rows
569+
off_diag = segment_sum(contrib, cols, num_segments=S) # (S, n_funcs)
570+
return off_diag
571+
572+
573+
def build_curvature_matrix_off_diag_with_light_profiles_via_w_tilde_from_func(psf: np.ndarray, y_shape: int, x_shape: int):
574+
575+
import jax
576+
import jax.numpy as jnp
577+
578+
psf = jnp.asarray(psf, dtype=jnp.float64)
579+
Ky, Kx = psf.shape
580+
fft_shape = (y_shape + Ky - 1, x_shape + Kx - 1)
581+
582+
psf_flip = jnp.flip(psf, axis=(0, 1))
583+
Khat_flip_r = precompute_Khat_rfft(psf_flip, fft_shape)
584+
585+
fn_jit = jax.jit(
586+
partial(
587+
curvature_matrix_off_diag_with_light_profiles_via_w_tilde_from,
588+
Khat_flip_r=Khat_flip_r,
589+
Ky=Ky,
590+
Kx=Kx,
591+
),
592+
static_argnames=("y_shape", "x_shape", "S"),
593+
)
594+
return fn_jit
595+
596+
524597
def mapped_image_rect_from_triplets(
525598
reconstruction, # (S,)
526599
rows,

autoarray/inversion/inversion/imaging/w_tilde.py

Lines changed: 49 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -88,12 +88,12 @@ def _data_vector_mapper(self) -> np.ndarray:
8888
rows, cols, vals = mapper.pixel_triplets_data
8989

9090
data_vector_mapper = (
91-
inversion_imaging_util.data_vector_via_w_tilde_data_imaging_from(
91+
inversion_imaging_util.data_vector_via_w_tilde_from(
9292
w_tilde_data=self.w_tilde_data,
9393
rows=rows,
9494
cols=cols,
9595
vals=vals,
96-
S=mapper.total_params,
96+
S=mapper.params,
9797
)
9898
)
9999
param_range = mapper_param_range[mapper_index]
@@ -199,15 +199,21 @@ def _data_vector_func_list_and_mapper(self) -> np.ndarray:
199199
linear_func
200200
]
201201

202-
diag = inversion_imaging_numba_util.data_vector_via_blurred_mapping_matrix_from(
203-
blurred_mapping_matrix=np.array(operated_mapping_matrix),
202+
diag = inversion_imaging_util.data_vector_via_blurred_mapping_matrix_from(
203+
blurred_mapping_matrix=operated_mapping_matrix,
204204
image=self.data.array,
205205
noise_map=self.noise_map.array,
206206
)
207207

208208
param_range = linear_func_param_range[linear_func_index]
209209

210-
data_vector[param_range[0] : param_range[1],] = diag
210+
start = param_range[0]
211+
end = param_range[1]
212+
213+
if self._xp is np:
214+
data_vector[start:end] = diag
215+
else:
216+
data_vector = data_vector.at[start:end].set(diag)
211217

212218
return data_vector
213219

@@ -421,22 +427,35 @@ def _curvature_matrix_func_list_and_mapper(self) -> np.ndarray:
421427
/ self.noise_map[:, None] ** 2
422428
)
423429

424-
off_diag = inversion_imaging_numba_util.curvature_matrix_off_diags_via_mapper_and_linear_func_curvature_vector_from(
425-
data_to_pix_unique=mapper.unique_mappings.data_to_pix_unique,
426-
data_weights=mapper.unique_mappings.data_weights,
427-
pix_lengths=mapper.unique_mappings.pix_lengths,
428-
pix_pixels=mapper.params,
429-
curvature_weights=np.array(curvature_weights),
430-
mask=self.mask.array,
431-
psf_kernel=self.psf.native.array,
430+
rows, cols, vals = mapper.pixel_triplets_curvature
431+
432+
off_diag = self.w_tilde.curvature_matrix_off_diag_light_profiles_func(
433+
curvature_weights=curvature_weights,
434+
fft_index_for_masked_pixel=self.mask.fft_index_for_masked_pixel,
435+
rows=rows,
436+
cols=cols,
437+
vals=vals,
438+
y_shape=self.mask.shape_native[0],
439+
x_shape=self.mask.shape_native[1],
440+
S=mapper.params,
432441
)
433442

434-
curvature_matrix[
435-
mapper_param_range[0] : mapper_param_range[1],
443+
if self._xp is np:
444+
445+
curvature_matrix[
446+
mapper_param_range[0] : mapper_param_range[1],
436447
linear_func_param_range[0] : linear_func_param_range[1],
437-
] = off_diag
448+
] = off_diag
449+
else:
450+
451+
curvature_matrix = curvature_matrix.at[
452+
mapper_param_range[0] : mapper_param_range[1],
453+
linear_func_param_range[0] : linear_func_param_range[1],
454+
].set(off_diag)
455+
438456

439457
for index_0, linear_func_0 in enumerate(linear_func_list):
458+
440459
linear_func_param_range_0 = linear_func_param_range_list[index_0]
441460

442461
weighted_vector_0 = (
@@ -452,15 +471,24 @@ def _curvature_matrix_func_list_and_mapper(self) -> np.ndarray:
452471
/ self.noise_map[:, None]
453472
)
454473

455-
diag = np.dot(
474+
diag = self._xp.dot(
456475
weighted_vector_0.T,
457476
weighted_vector_1,
458477
)
459478

460-
curvature_matrix[
461-
linear_func_param_range_0[0] : linear_func_param_range_0[1],
462-
linear_func_param_range_1[0] : linear_func_param_range_1[1],
463-
] = diag
479+
if self._xp is np:
480+
481+
curvature_matrix[
482+
linear_func_param_range_0[0] : linear_func_param_range_0[1],
483+
linear_func_param_range_1[0] : linear_func_param_range_1[1],
484+
] = diag
485+
486+
else:
487+
488+
curvature_matrix = curvature_matrix.at[
489+
linear_func_param_range_0[0] : linear_func_param_range_0[1],
490+
linear_func_param_range_1[0] : linear_func_param_range_1[1],
491+
].set(diag)
464492

465493
return curvature_matrix
466494

autoarray/inversion/inversion/interferometer/w_tilde.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -114,7 +114,7 @@ def curvature_matrix_diag(self) -> np.ndarray:
114114
pix_indexes_for_sub_slim_index=mapper.pix_indexes_for_sub_slim_index,
115115
pix_weights_for_sub_slim_index=mapper.pix_weights_for_sub_slim_index,
116116
pix_pixels=self.linear_obj_list[0].params,
117-
fft_index_for_masked_pixel=self.w_tilde.fft_index_for_masked_pixel,
117+
fft_index_for_masked_pixel=self.mask.fft_index_for_masked_pixel,
118118
)
119119

120120
@property

test_autoarray/dataset/imaging/test_dataset.py

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -193,12 +193,6 @@ def test__apply_mask(imaging_7x7, mask_2d_7x7, psf_3x3):
193193

194194
assert type(masked_imaging_7x7.psf) == aa.Kernel2D
195195

196-
masked_imaging_7x7 = masked_imaging_7x7.apply_w_tilde()
197-
198-
assert masked_imaging_7x7.w_tilde.curvature_preload.shape == (35,)
199-
assert masked_imaging_7x7.w_tilde.indexes.shape == (35,)
200-
assert masked_imaging_7x7.w_tilde.lengths.shape == (9,)
201-
202196

203197
def test__apply_noise_scaling(imaging_7x7, mask_2d_7x7):
204198
masked_imaging_7x7 = imaging_7x7.apply_noise_scaling(

test_autoarray/inversion/inversion/interferometer/test_inversion_interferometer_util.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -128,7 +128,7 @@ def test__curvature_matrix_via_curvature_preload_from():
128128
fft_state=w_tilde.fft_state,
129129
pix_indexes_for_sub_slim_index=pix_indexes_for_sub_slim_index,
130130
pix_weights_for_sub_slim_index=pix_weights_for_sub_slim_index,
131-
fft_index_for_masked_pixel=w_tilde.fft_index_for_masked_pixel,
131+
fft_index_for_masked_pixel=grid.mask.fft_index_for_masked_pixel,
132132
pix_pixels=3,
133133
)
134134

0 commit comments

Comments
 (0)