Skip to content

Commit f9fc511

Browse files
Jammy2211Jammy2211
authored andcommitted
data vector stuff all jaxd
1 parent 6fcbd08 commit f9fc511

File tree

10 files changed

+220
-85
lines changed

10 files changed

+220
-85
lines changed

autoarray/dataset/abstract/w_tilde.py

Lines changed: 1 addition & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -52,17 +52,4 @@ def fft_index_for_masked_pixel(self) -> np.ndarray:
5252
- This method is intentionally backend-agnostic and can be used by both
5353
imaging and interferometer curvature pipelines.
5454
"""
55-
56-
# Boolean mask defined on the rectangular FFT grid
57-
# True = masked pixel
58-
# False = unmasked pixel
59-
mask_fft = self.fft_mask
60-
61-
# Coordinates of unmasked pixels in the FFT grid
62-
ys, xs = np.where(~mask_fft)
63-
64-
# Width of the FFT grid (number of columns)
65-
width = mask_fft.shape[1]
66-
67-
# Convert (y, x) coordinates to flat row-major indices
68-
return (ys * width + xs).astype(np.int32)
55+
self.fft_mask.fft_index_for_masked_pixel

autoarray/dataset/imaging/dataset.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -505,6 +505,7 @@ def apply_w_tilde(
505505
"""
506506

507507
w_tilde = WTildeImaging(
508+
data=self.data,
508509
noise_map=self.noise_map,
509510
psf=self.psf,
510511
fft_mask=self.mask,

autoarray/dataset/imaging/w_tilde.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
class WTildeImaging(AbstractWTilde):
1313
def __init__(
1414
self,
15+
data: np.ndarray,
1516
noise_map: np.ndarray,
1617
psf: np.ndarray,
1718
fft_mask: np.ndarray,
@@ -41,9 +42,13 @@ def __init__(
4142
fft_mask=fft_mask
4243
)
4344

45+
self.data = data
4446
self.noise_map = noise_map
4547
self.psf = psf
4648

49+
self.data_native = data.native
50+
self.noise_map_native = noise_map.native
51+
4752
@property
4853
def psf_operator_matrix_dense(self):
4954

autoarray/dataset/interferometer/w_tilde.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -250,8 +250,6 @@ def __init__(
250250
curvature_preload=self.curvature_preload, batch_size=batch_size
251251
)
252252

253-
254-
255253
def save_curvature_preload(
256254
self,
257255
file: Union[str, Path],

autoarray/inversion/inversion/imaging/inversion_imaging_util.py

Lines changed: 23 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -1,38 +1,5 @@
11
import numpy as np
22

3-
def pixel_triplets_from_subpixel_arrays_from(
4-
pix_indexes_for_sub, # (M_sub, P)
5-
pix_weights_for_sub, # (M_sub, P)
6-
slim_index_for_sub, # (M_sub,)
7-
fft_index_for_masked_pixel, # (N_unmasked,)
8-
sub_fraction_slim, # (N_unmasked,)
9-
):
10-
"""
11-
Build sparse source→image mapping triplets (rows, cols, vals)
12-
for a fixed-size interpolation stencil.
13-
14-
Assumptions:
15-
- Every subpixel maps to exactly P source pixels
16-
- All entries in pix_indexes_for_sub are valid
17-
- No padding / ragged rows needed
18-
"""
19-
import jax.numpy as jnp
20-
21-
22-
M_sub, P = pix_indexes_for_sub.shape
23-
24-
sub_ids = jnp.repeat(jnp.arange(M_sub, dtype=jnp.int32), P)
25-
26-
cols = pix_indexes_for_sub.reshape(-1).astype(jnp.int32)
27-
vals = pix_weights_for_sub.reshape(-1).astype(jnp.float64)
28-
29-
slim_rows = slim_index_for_sub[sub_ids].astype(jnp.int32)
30-
rows = fft_index_for_masked_pixel[slim_rows].astype(jnp.int32)
31-
32-
vals = vals * sub_fraction_slim[slim_rows].astype(jnp.float64)
33-
return rows, cols, vals
34-
35-
363
def psf_operator_matrix_dense_from(
374
kernel_native: np.ndarray,
385
native_index_for_slim_index: np.ndarray, # shape (N_pix, 2), native (y,x) coords of masked pixels
@@ -159,6 +126,29 @@ def w_tilde_data_imaging_from(
159126
return xp.sum(patches * kernel_native[None, :, :], axis=(1, 2)) # (N,)
160127

161128

129+
def data_vector_via_w_tilde_from(
130+
w_tilde_data: np.ndarray, # (M_pix,) float64
131+
rows: np.ndarray, # (nnz,) int32 each triplet's data pixel (slim index)
132+
cols: np.ndarray, # (nnz,) int32 source pixel index
133+
vals: np.ndarray, # (nnz,) float64 mapping weights incl sub_fraction
134+
S: int, # number of source pixels
135+
) -> np.ndarray:
136+
"""
137+
Replacement for numba data_vector_via_w_tilde_data_imaging_from using triplets.
138+
139+
Computes:
140+
D[p] = sum_{triplets t with col_t=p} vals[t] * w_tilde_data_slim[slim_rows[t]]
141+
142+
Returns:
143+
(S,) float64
144+
"""
145+
from jax.ops import segment_sum
146+
147+
w = w_tilde_data[rows] # (nnz,)
148+
contrib = vals * w # (nnz,)
149+
return segment_sum(contrib, cols, num_segments=S) # (S,)
150+
151+
162152
def data_vector_via_blurred_mapping_matrix_from(
163153
blurred_mapping_matrix: np.ndarray, image: np.ndarray, noise_map: np.ndarray
164154
) -> np.ndarray:

autoarray/inversion/inversion/imaging/w_tilde.py

Lines changed: 19 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -58,10 +58,11 @@ def __init__(
5858
@cached_property
5959
def w_tilde_data(self):
6060
return inversion_imaging_util.w_tilde_data_imaging_from(
61-
image_native=self.data.native.array,
62-
noise_map_native=self.noise_map.native.array,
63-
kernel_native=self.psf.native.array,
61+
image_native=self.w_tilde.data_native.array,
62+
noise_map_native=self.w_tilde.noise_map_native.array,
63+
kernel_native=self.psf.stored_native,
6464
native_index_for_slim_index=self.data.mask.derive_indexes.native_for_slim,
65+
xp=self._xp
6566
)
6667

6768
@property
@@ -83,15 +84,16 @@ def _data_vector_mapper(self) -> np.ndarray:
8384
mapper_param_range = self.param_range_list_from(cls=AbstractMapper)
8485

8586
for mapper_index, mapper in enumerate(mapper_list):
87+
88+
rows, cols, vals = mapper.pixel_triplets
89+
8690
data_vector_mapper = (
87-
inversion_imaging_numba_util.data_vector_via_w_tilde_data_imaging_from(
91+
inversion_imaging_util.data_vector_via_w_tilde_data_imaging_from(
8892
w_tilde_data=self.w_tilde_data,
89-
data_to_pix_unique=np.array(
90-
mapper.unique_mappings.data_to_pix_unique
91-
),
92-
data_weights=np.array(mapper.unique_mappings.data_weights),
93-
pix_lengths=np.array(mapper.unique_mappings.pix_lengths),
94-
pix_pixels=mapper.params,
93+
rows=rows,
94+
cols=cols,
95+
vals=vals,
96+
S=mapper.total_params,
9597
)
9698
)
9799
param_range = mapper_param_range[mapper_index]
@@ -131,12 +133,14 @@ def _data_vector_x1_mapper(self) -> np.ndarray:
131133
"""
132134
linear_obj = self.linear_obj_list[0]
133135

134-
return inversion_imaging_numba_util.data_vector_via_w_tilde_data_imaging_from(
136+
rows, cols, vals = linear_obj.pixel_triplets
137+
138+
return inversion_imaging_util.data_vector_via_w_tilde_from(
135139
w_tilde_data=self.w_tilde_data,
136-
data_to_pix_unique=linear_obj.unique_mappings.data_to_pix_unique,
137-
data_weights=linear_obj.unique_mappings.data_weights,
138-
pix_lengths=linear_obj.unique_mappings.pix_lengths,
139-
pix_pixels=linear_obj.params,
140+
rows=rows,
141+
cols=cols,
142+
vals=vals,
143+
S=linear_obj.params,
140144
)
141145

142146
@property

autoarray/inversion/pixelization/mappers/abstract.py

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -268,6 +268,20 @@ def mapping_matrix(self) -> np.ndarray:
268268
xp=self._xp,
269269
)
270270

271+
@cached_property
272+
def pixel_triplets(self):
273+
274+
rows, cols, vals = mapper_util.pixel_triplets_from_subpixel_arrays_from(
275+
pix_indexes_for_sub=self.pix_indexes_for_sub_slim_index,
276+
pix_weights_for_sub=self.pix_weights_for_sub_slim_index,
277+
slim_index_for_sub=self.slim_index_for_sub_slim_index,
278+
fft_index_for_masked_pixel=self.mapper_grids.mask.fft_index_for_masked_pixel,
279+
sub_fraction_slim=self.over_sampler.sub_fraction.array,
280+
xp=self._xp
281+
)
282+
283+
return rows, cols, vals
284+
271285
def pixel_signals_from(self, signal_scale: float, xp=np) -> np.ndarray:
272286
"""
273287
Returns the signal in each pixelization pixel, where this signal is an estimate of the expected signal
@@ -410,7 +424,6 @@ def extent_from(
410424
extent=self.source_plane_mesh_grid.geometry.extent
411425
)
412426

413-
414427
class PixSubWeights:
415428
def __init__(self, mappings: np.ndarray, sizes: np.ndarray, weights: np.ndarray):
416429
"""

autoarray/inversion/pixelization/mappers/mapper_util.py

Lines changed: 96 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -442,6 +442,102 @@ def adaptive_pixel_signals_from(
442442
return pixel_signals**signal_scale
443443

444444

445+
import numpy as np
446+
447+
def pixel_triplets_from_subpixel_arrays_from(
448+
pix_indexes_for_sub, # (M_sub, P)
449+
pix_weights_for_sub, # (M_sub, P)
450+
slim_index_for_sub, # (M_sub,)
451+
fft_index_for_masked_pixel, # (N_unmasked,)
452+
sub_fraction_slim, # (N_unmasked,)
453+
*,
454+
xp=np,
455+
):
456+
"""
457+
Build sparse source→image mapping triplets (rows, cols, vals)
458+
for a fixed-size interpolation stencil.
459+
460+
This supports both:
461+
- NumPy (xp=np)
462+
- JAX (xp=jax.numpy)
463+
464+
Parameters
465+
----------
466+
pix_indexes_for_sub
467+
Source pixel indices for each subpixel (M_sub, P)
468+
pix_weights_for_sub
469+
Interpolation weights for each subpixel (M_sub, P)
470+
slim_index_for_sub
471+
Mapping subpixel -> slim image pixel index (M_sub,)
472+
fft_index_for_masked_pixel
473+
Mapping slim pixel -> rectangular FFT-grid pixel index (N_unmasked,)
474+
sub_fraction_slim
475+
Oversampling normalization per slim pixel (N_unmasked,)
476+
xp
477+
Backend module (np or jnp)
478+
479+
Returns
480+
-------
481+
rows : (nnz,) int32
482+
Rectangular FFT grid row index per mapping entry
483+
cols : (nnz,) int32
484+
Source pixel index per mapping entry
485+
vals : (nnz,) float64
486+
Mapping weight per entry including sub_fraction normalization
487+
"""
488+
489+
# ------------------------------------------------------------
490+
# Put everything on the right backend
491+
# ------------------------------------------------------------
492+
pix_indexes_for_sub = xp.asarray(pix_indexes_for_sub)
493+
pix_weights_for_sub = xp.asarray(pix_weights_for_sub)
494+
slim_index_for_sub = xp.asarray(slim_index_for_sub)
495+
fft_index_for_masked_pixel = xp.asarray(fft_index_for_masked_pixel)
496+
sub_fraction_slim = xp.asarray(sub_fraction_slim)
497+
498+
# dtypes (important for JAX scatter / indexing performance)
499+
pix_indexes_for_sub = pix_indexes_for_sub.astype(xp.int32)
500+
pix_weights_for_sub = pix_weights_for_sub.astype(xp.float64)
501+
slim_index_for_sub = slim_index_for_sub.astype(xp.int32)
502+
fft_index_for_masked_pixel = fft_index_for_masked_pixel.astype(xp.int32)
503+
sub_fraction_slim = sub_fraction_slim.astype(xp.float64)
504+
505+
# ------------------------------------------------------------
506+
# Dimensions
507+
# ------------------------------------------------------------
508+
M_sub, P = pix_indexes_for_sub.shape
509+
510+
# ------------------------------------------------------------
511+
# Subpixel IDs repeated P times (fixed stencil)
512+
# ------------------------------------------------------------
513+
sub_ids = xp.repeat(xp.arange(M_sub, dtype=xp.int32), P) # (M_sub*P,)
514+
515+
# ------------------------------------------------------------
516+
# Flatten interpolation stencil
517+
# ------------------------------------------------------------
518+
cols = pix_indexes_for_sub.reshape(-1).astype(xp.int32) # (nnz,)
519+
vals = pix_weights_for_sub.reshape(-1).astype(xp.float64) # (nnz,)
520+
521+
# ------------------------------------------------------------
522+
# subpixel -> slim image pixel
523+
# ------------------------------------------------------------
524+
slim_rows = slim_index_for_sub[sub_ids].astype(xp.int32) # (nnz,)
525+
526+
# ------------------------------------------------------------
527+
# slim pixel -> FFT rectangular pixel
528+
# ------------------------------------------------------------
529+
rows = fft_index_for_masked_pixel[slim_rows].astype(xp.int32)
530+
531+
# ------------------------------------------------------------
532+
# Oversampling normalization
533+
# ------------------------------------------------------------
534+
vals = vals * sub_fraction_slim[slim_rows].astype(xp.float64)
535+
536+
return rows, cols, vals
537+
538+
539+
540+
445541
def mapping_matrix_from(
446542
pix_indexes_for_sub_slim_index: np.ndarray,
447543
pix_size_for_sub_slim_index: np.ndarray,

autoarray/mask/mask_2d.py

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -619,6 +619,48 @@ def from_fits(
619619
def shape_native(self) -> Tuple[int, ...]:
620620
return self.shape
621621

622+
@property
623+
def fft_index_for_masked_pixel(self) -> np.ndarray:
624+
"""
625+
Return a mapping from masked-pixel (slim) indices to flat indices
626+
on the rectangular FFT grid.
627+
628+
This array is used to translate between:
629+
630+
- "masked pixel space" (a compact 1D indexing over unmasked pixels)
631+
- the 2D rectangular grid on which FFT-based convolutions are performed
632+
633+
The FFT grid is assumed to be rectangular and already suitable for FFTs
634+
(e.g. padded and centered appropriately). Masked pixels are present on
635+
this grid but are ignored in computations via zero-weighting.
636+
637+
Returns
638+
-------
639+
np.ndarray
640+
A 1D array of shape (N_unmasked,), where element `i` gives the flat
641+
(row-major) index into the FFT grid corresponding to the `i`-th
642+
unmasked pixel in slim ordering.
643+
644+
Notes
645+
-----
646+
- The slim ordering is defined as the order returned by `np.where(~mask)`.
647+
- The flat FFT index is computed assuming row-major (C-style) ordering:
648+
flat_index = y * width + x
649+
- This method is intentionally backend-agnostic and can be used by both
650+
imaging and interferometer curvature pipelines.
651+
"""
652+
# Boolean mask defined on the rectangular FFT grid
653+
mask_fft = self
654+
655+
# Coordinates of unmasked pixels in the FFT grid
656+
ys, xs = np.where(~mask_fft)
657+
658+
# Width of the FFT grid (number of columns)
659+
width = mask_fft.shape[1]
660+
661+
# Convert (y, x) coordinates to flat row-major indices
662+
return (ys * width + xs).astype(np.int32)
663+
622664
def trimmed_array_from(self, padded_array, image_shape) -> Array2D:
623665
"""
624666
Map a padded 1D array of values to its original 2D array, trimming all edge values.

0 commit comments

Comments
 (0)