Skip to content

Commit 8270894

Browse files
Jammy2211Jammy2211
authored andcommitted
mapping mative native from to avoid repeated code
1 parent 44dbadd commit 8270894

File tree

4 files changed

+107
-78
lines changed

4 files changed

+107
-78
lines changed

autoarray/dataset/imaging/dataset.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
import logging
22
import numpy as np
33
from pathlib import Path
4-
import scipy
54
from typing import Optional, Union
65

76
from autoconf import cached_property

autoarray/geometry/geometry_1d.py

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,6 @@
11
from __future__ import annotations
22
import logging
3-
import numpy as np
4-
from typing import TYPE_CHECKING, List, Tuple, Union
5-
6-
if TYPE_CHECKING:
7-
from autoarray.structures.grids.uniform_1d import Grid1D
8-
from autoarray.mask.mask_2d import Mask2D
3+
from typing import Tuple
94

105
from autoarray import type as ty
116

autoarray/mask/derive/mask_2d.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
from __future__ import annotations
22
import logging
3-
import copy
43
import numpy as np
54
from typing import TYPE_CHECKING, Tuple
65

@@ -10,7 +9,6 @@
109
from autoarray import exc
1110
from autoarray.mask.derive.indexes_2d import DeriveIndexes2D
1211

13-
from autoarray.structures.arrays import array_2d_util
1412
from autoarray.mask import mask_2d_util
1513

1614
logging.basicConfig()

autoarray/structures/arrays/kernel_2d.py

Lines changed: 106 additions & 69 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,10 @@
1+
from __future__ import annotations
2+
3+
from typing import TYPE_CHECKING
4+
5+
if TYPE_CHECKING:
6+
from autoarray import Mask2D
7+
18
import jax
29
import jax.numpy as jnp
310
import numpy as np
@@ -14,9 +21,7 @@
1421
from autoarray.structures.grids.uniform_2d import Grid2D
1522
from autoarray.structures.header import Header
1623

17-
from autoarray import exc
1824
from autoarray import type as ty
19-
from autoarray.structures.arrays import array_2d_util
2025

2126

2227
class Kernel2D(AbstractArray2D):
@@ -542,6 +547,83 @@ def normalized(self) -> "Kernel2D":
542547
"""
543548
return Kernel2D(values=self, mask=self.mask, normalize=True)
544549

550+
def mapping_matrix_native_from(
551+
self,
552+
mapping_matrix: jnp.ndarray,
553+
mask: "Mask2D",
554+
blurring_mapping_matrix: Optional[jnp.ndarray] = None,
555+
blurring_mask: Optional["Mask2D"] = None,
556+
) -> jnp.ndarray:
557+
"""
558+
Expand a slim mapping matrix (image-plane) and optional blurring mapping matrix
559+
into a full native 3D cube (ny, nx, n_src).
560+
561+
This is primarily used for real-space convolution, where the pixel-to-source
562+
mapping must be represented on the full image grid.
563+
564+
Parameters
565+
----------
566+
mapping_matrix : ndarray (N_pix, N_src)
567+
Slim mapping matrix for unmasked image pixels, mapping each image pixel
568+
to source-plane pixels.
569+
mask : Mask2D
570+
Mask defining which image pixels are unmasked. Used to expand the slim
571+
mapping matrix into a native grid.
572+
blurring_mapping_matrix : ndarray (N_blur, N_src), optional
573+
Mapping matrix for blurring pixels outside the main mask (e.g. light
574+
spilling in from outside). If provided, it is also scattered into the
575+
native cube.
576+
blurring_mask : Mask2D, optional
577+
Mask defining the blurring region pixels. Must be provided if
578+
`blurring_mapping_matrix` is given and `slim_to_native_blurring_tuple`
579+
is not already cached.
580+
581+
Returns
582+
-------
583+
ndarray (ny, nx, N_src)
584+
Native 3D mapping matrix cube with dimensions (image_y, image_x, sources).
585+
Contains contributions from both the main mapping matrix and, if provided,
586+
the blurring mapping matrix.
587+
"""
588+
slim_to_native_tuple = self.slim_to_native_tuple
589+
if slim_to_native_tuple is None:
590+
slim_to_native_tuple = jnp.nonzero(
591+
jnp.logical_not(mask.array), size=mapping_matrix.shape[0]
592+
)
593+
594+
n_src = mapping_matrix.shape[1]
595+
596+
# Allocate full native grid (ny, nx, n_src)
597+
mapping_matrix_native = jnp.zeros(
598+
mask.shape + (n_src,), dtype=mapping_matrix.dtype
599+
)
600+
601+
# Scatter main mapping matrix into native cube
602+
mapping_matrix_native = mapping_matrix_native.at[slim_to_native_tuple].set(
603+
mapping_matrix
604+
)
605+
606+
# Optionally scatter blurring mapping matrix
607+
if blurring_mapping_matrix is not None:
608+
slim_to_native_blurring_tuple = self.slim_to_native_blurring_tuple
609+
610+
if slim_to_native_blurring_tuple is None:
611+
if blurring_mask is None:
612+
raise ValueError(
613+
"blurring_mask must be provided if blurring_mapping_matrix is given "
614+
"and slim_to_native_blurring_tuple is None."
615+
)
616+
slim_to_native_blurring_tuple = jnp.nonzero(
617+
jnp.logical_not(blurring_mask.array),
618+
size=blurring_mapping_matrix.shape[0],
619+
)
620+
621+
mapping_matrix_native = mapping_matrix_native.at[
622+
slim_to_native_blurring_tuple
623+
].set(blurring_mapping_matrix)
624+
625+
return mapping_matrix_native
626+
545627
def convolved_image_from(self, image, blurring_image, jax_method="direct"):
546628
"""
547629
Convolve an input masked image with this PSF.
@@ -665,6 +747,7 @@ def convolved_mapping_matrix_from(
665747
mapping_matrix,
666748
mask,
667749
blurring_mapping_matrix=None,
750+
blurring_mask: Optional[Mask2D] = None,
668751
jax_method="direct",
669752
):
670753
"""
@@ -716,6 +799,7 @@ def convolved_mapping_matrix_from(
716799
mapping_matrix=mapping_matrix,
717800
mask=mask,
718801
blurring_mapping_matrix=blurring_mapping_matrix,
802+
blurring_mask=blurring_mask,
719803
jax_method=jax_method,
720804
)
721805

@@ -735,35 +819,22 @@ def convolved_mapping_matrix_from(
735819
fft_shape = self.fft_shape
736820
full_shape = self.full_shape
737821
mask_shape = self.mask_shape
738-
fft_psf = self.fft_psf
739822
fft_psf_mapping = self.fft_psf_mapping
740823

741824
slim_to_native_tuple = self.slim_to_native_tuple
742825

743826
if slim_to_native_tuple is None:
744827
slim_to_native_tuple = jnp.nonzero(
745-
jnp.logical_not(mask.array), size=mask.shape[0]
828+
jnp.logical_not(mask.array), size=mapping_matrix.shape[0]
746829
)
747830

748-
n_src = mapping_matrix.shape[1]
749-
750-
# allocate full native + source dimension
751-
mapping_matrix_native = jnp.zeros(
752-
mask.shape + (n_src,), dtype=mapping_matrix.dtype
753-
)
754-
755-
# scatter main mapping matrix
756-
mapping_matrix_native = mapping_matrix_native.at[slim_to_native_tuple].set(
757-
mapping_matrix
831+
mapping_matrix_native = self.mapping_matrix_native_from(
832+
mapping_matrix=mapping_matrix,
833+
mask=mask,
834+
blurring_mapping_matrix=blurring_mapping_matrix,
835+
blurring_mask=blurring_mask,
758836
)
759837

760-
# optionally scatter blurring mapping matrix
761-
if blurring_mapping_matrix is not None:
762-
slim_to_native_blurring_tuple = self.slim_to_native_blurring_tuple
763-
mapping_matrix_native = mapping_matrix_native.at[
764-
slim_to_native_blurring_tuple
765-
].set(blurring_mapping_matrix)
766-
767838
# FFT convolution
768839
fft_mapping_matrix_native = jnp.fft.rfft2(
769840
mapping_matrix_native, s=fft_shape, axes=(0, 1)
@@ -960,6 +1031,7 @@ def convolved_mapping_matrix_via_real_space_from(
9601031
mapping_matrix: np.ndarray,
9611032
mask,
9621033
blurring_mapping_matrix: Optional[np.ndarray] = None,
1034+
blurring_mask: Optional[Mask2D] = None,
9631035
jax_method: str = "direct",
9641036
):
9651037
"""
@@ -989,60 +1061,25 @@ def convolved_mapping_matrix_via_real_space_from(
9891061
ndarray (N_pix, N_src)
9901062
Convolved mapping matrix in slim form.
9911063
"""
992-
# 1) Indices of unmasked (image) pixels — no `size=` to avoid wrong lengths
993-
ys, xs = self.slim_to_native_tuple or jnp.nonzero(jnp.logical_not(mask.array))
994-
n_pix, n_src = mapping_matrix.shape
995-
996-
# Sanity check
997-
if ys.shape[0] != n_pix:
998-
raise ValueError(
999-
f"Mapping rows ({n_pix}) != unmasked pixels ({ys.shape[0]}). "
1000-
"Make sure you’re using the image (not blurring) index tuple."
1001-
)
1002-
1003-
# 2) Allocate native cube (ny, nx, n_src)
1004-
mapping_matrix_native = jnp.zeros(
1005-
mask.shape + (n_src,), dtype=mapping_matrix.dtype
1006-
)
10071064

1008-
# 3) Build index grids with identical shape (n_pix, n_src)
1009-
ys_exp = jnp.broadcast_to(ys[:, None], (n_pix, n_src))
1010-
xs_exp = jnp.broadcast_to(xs[:, None], (n_pix, n_src))
1011-
src_exp = jnp.broadcast_to(jnp.arange(n_src)[None, :], (n_pix, n_src))
1012-
1013-
# 4) Scatter all at once (values also shape (n_pix, n_src))
1014-
mapping_matrix_native = mapping_matrix_native.at[(ys_exp, xs_exp, src_exp)].set(
1015-
mapping_matrix
1016-
)
1065+
slim_to_native_tuple = self.slim_to_native_tuple
10171066

1018-
# 5) Optional blurring mapping matrix
1019-
if blurring_mapping_matrix is not None:
1020-
ys_b, xs_b = self.slim_to_native_blurring_tuple or jnp.nonzero(
1021-
jnp.logical_not(
1022-
mask.array
1023-
) # use the correct blurring grid mask here if different
1067+
if slim_to_native_tuple is None:
1068+
slim_to_native_tuple = jnp.nonzero(
1069+
jnp.logical_not(mask.array), size=mapping_matrix.shape[0]
10241070
)
1025-
n_blur, n_src_b = blurring_mapping_matrix.shape
1026-
if n_src_b != n_src:
1027-
raise ValueError(
1028-
"blurring_mapping_matrix columns must match mapping_matrix columns (n_src)."
1029-
)
1030-
1031-
ys_b_exp = jnp.broadcast_to(ys_b[:, None], (n_blur, n_src))
1032-
xs_b_exp = jnp.broadcast_to(xs_b[:, None], (n_blur, n_src))
1033-
src_b_exp = jnp.broadcast_to(jnp.arange(n_src)[None, :], (n_blur, n_src))
1034-
1035-
mapping_matrix_native = mapping_matrix_native.at[
1036-
(ys_b_exp, xs_b_exp, src_b_exp)
1037-
].set(blurring_mapping_matrix)
10381071

1072+
mapping_matrix_native = self.mapping_matrix_native_from(
1073+
mapping_matrix=mapping_matrix,
1074+
mask=mask,
1075+
blurring_mapping_matrix=blurring_mapping_matrix,
1076+
blurring_mask=blurring_mask,
1077+
)
10391078
# 6) Real-space convolution, broadcast kernel over source axis
10401079
kernel = self.stored_native.array
1041-
convolved_native = jax.scipy.signal.convolve(
1080+
blurred_mapping_matrix_native = jax.scipy.signal.convolve(
10421081
mapping_matrix_native, kernel[..., None], mode="same", method=jax_method
10431082
)
10441083

1045-
# 7) Pull back to slim (n_pix, n_src)
1046-
blurred_mapping_matrix = convolved_native[ys, xs, :]
1047-
1048-
return blurred_mapping_matrix
1084+
# return slim form
1085+
return blurred_mapping_matrix_native[slim_to_native_tuple]

0 commit comments

Comments
 (0)