1+ from __future__ import annotations
2+
3+ from typing import TYPE_CHECKING
4+
5+ if TYPE_CHECKING :
6+ from autoarray import Mask2D
7+
18import jax
29import jax .numpy as jnp
310import numpy as np
1421from autoarray .structures .grids .uniform_2d import Grid2D
1522from autoarray .structures .header import Header
1623
17- from autoarray import exc
1824from autoarray import type as ty
19- from autoarray .structures .arrays import array_2d_util
2025
2126
2227class Kernel2D (AbstractArray2D ):
@@ -542,6 +547,83 @@ def normalized(self) -> "Kernel2D":
542547 """
543548 return Kernel2D (values = self , mask = self .mask , normalize = True )
544549
550+ def mapping_matrix_native_from (
551+ self ,
552+ mapping_matrix : jnp .ndarray ,
553+ mask : "Mask2D" ,
554+ blurring_mapping_matrix : Optional [jnp .ndarray ] = None ,
555+ blurring_mask : Optional ["Mask2D" ] = None ,
556+ ) -> jnp .ndarray :
557+ """
558+ Expand a slim mapping matrix (image-plane) and optional blurring mapping matrix
559+ into a full native 3D cube (ny, nx, n_src).
560+
561+ This is primarily used for real-space convolution, where the pixel-to-source
562+ mapping must be represented on the full image grid.
563+
564+ Parameters
565+ ----------
566+ mapping_matrix : ndarray (N_pix, N_src)
567+ Slim mapping matrix for unmasked image pixels, mapping each image pixel
568+ to source-plane pixels.
569+ mask : Mask2D
570+ Mask defining which image pixels are unmasked. Used to expand the slim
571+ mapping matrix into a native grid.
572+ blurring_mapping_matrix : ndarray (N_blur, N_src), optional
573+ Mapping matrix for blurring pixels outside the main mask (e.g. light
574+ spilling in from outside). If provided, it is also scattered into the
575+ native cube.
576+ blurring_mask : Mask2D, optional
577+ Mask defining the blurring region pixels. Must be provided if
578+ `blurring_mapping_matrix` is given and `slim_to_native_blurring_tuple`
579+ is not already cached.
580+
581+ Returns
582+ -------
583+ ndarray (ny, nx, N_src)
584+ Native 3D mapping matrix cube with dimensions (image_y, image_x, sources).
585+ Contains contributions from both the main mapping matrix and, if provided,
586+ the blurring mapping matrix.
587+ """
588+ slim_to_native_tuple = self .slim_to_native_tuple
589+ if slim_to_native_tuple is None :
590+ slim_to_native_tuple = jnp .nonzero (
591+ jnp .logical_not (mask .array ), size = mapping_matrix .shape [0 ]
592+ )
593+
594+ n_src = mapping_matrix .shape [1 ]
595+
596+ # Allocate full native grid (ny, nx, n_src)
597+ mapping_matrix_native = jnp .zeros (
598+ mask .shape + (n_src ,), dtype = mapping_matrix .dtype
599+ )
600+
601+ # Scatter main mapping matrix into native cube
602+ mapping_matrix_native = mapping_matrix_native .at [slim_to_native_tuple ].set (
603+ mapping_matrix
604+ )
605+
606+ # Optionally scatter blurring mapping matrix
607+ if blurring_mapping_matrix is not None :
608+ slim_to_native_blurring_tuple = self .slim_to_native_blurring_tuple
609+
610+ if slim_to_native_blurring_tuple is None :
611+ if blurring_mask is None :
612+ raise ValueError (
613+ "blurring_mask must be provided if blurring_mapping_matrix is given "
614+ "and slim_to_native_blurring_tuple is None."
615+ )
616+ slim_to_native_blurring_tuple = jnp .nonzero (
617+ jnp .logical_not (blurring_mask .array ),
618+ size = blurring_mapping_matrix .shape [0 ],
619+ )
620+
621+ mapping_matrix_native = mapping_matrix_native .at [
622+ slim_to_native_blurring_tuple
623+ ].set (blurring_mapping_matrix )
624+
625+ return mapping_matrix_native
626+
545627 def convolved_image_from (self , image , blurring_image , jax_method = "direct" ):
546628 """
547629 Convolve an input masked image with this PSF.
@@ -665,6 +747,7 @@ def convolved_mapping_matrix_from(
665747 mapping_matrix ,
666748 mask ,
667749 blurring_mapping_matrix = None ,
750+ blurring_mask : Optional [Mask2D ] = None ,
668751 jax_method = "direct" ,
669752 ):
670753 """
@@ -716,6 +799,7 @@ def convolved_mapping_matrix_from(
716799 mapping_matrix = mapping_matrix ,
717800 mask = mask ,
718801 blurring_mapping_matrix = blurring_mapping_matrix ,
802+ blurring_mask = blurring_mask ,
719803 jax_method = jax_method ,
720804 )
721805
@@ -735,35 +819,22 @@ def convolved_mapping_matrix_from(
735819 fft_shape = self .fft_shape
736820 full_shape = self .full_shape
737821 mask_shape = self .mask_shape
738- fft_psf = self .fft_psf
739822 fft_psf_mapping = self .fft_psf_mapping
740823
741824 slim_to_native_tuple = self .slim_to_native_tuple
742825
743826 if slim_to_native_tuple is None :
744827 slim_to_native_tuple = jnp .nonzero (
745- jnp .logical_not (mask .array ), size = mask .shape [0 ]
828+ jnp .logical_not (mask .array ), size = mapping_matrix .shape [0 ]
746829 )
747830
748- n_src = mapping_matrix .shape [1 ]
749-
750- # allocate full native + source dimension
751- mapping_matrix_native = jnp .zeros (
752- mask .shape + (n_src ,), dtype = mapping_matrix .dtype
753- )
754-
755- # scatter main mapping matrix
756- mapping_matrix_native = mapping_matrix_native .at [slim_to_native_tuple ].set (
757- mapping_matrix
831+ mapping_matrix_native = self .mapping_matrix_native_from (
832+ mapping_matrix = mapping_matrix ,
833+ mask = mask ,
834+ blurring_mapping_matrix = blurring_mapping_matrix ,
835+ blurring_mask = blurring_mask ,
758836 )
759837
760- # optionally scatter blurring mapping matrix
761- if blurring_mapping_matrix is not None :
762- slim_to_native_blurring_tuple = self .slim_to_native_blurring_tuple
763- mapping_matrix_native = mapping_matrix_native .at [
764- slim_to_native_blurring_tuple
765- ].set (blurring_mapping_matrix )
766-
767838 # FFT convolution
768839 fft_mapping_matrix_native = jnp .fft .rfft2 (
769840 mapping_matrix_native , s = fft_shape , axes = (0 , 1 )
@@ -960,6 +1031,7 @@ def convolved_mapping_matrix_via_real_space_from(
9601031 mapping_matrix : np .ndarray ,
9611032 mask ,
9621033 blurring_mapping_matrix : Optional [np .ndarray ] = None ,
1034+ blurring_mask : Optional [Mask2D ] = None ,
9631035 jax_method : str = "direct" ,
9641036 ):
9651037 """
@@ -989,60 +1061,25 @@ def convolved_mapping_matrix_via_real_space_from(
9891061 ndarray (N_pix, N_src)
9901062 Convolved mapping matrix in slim form.
9911063 """
992- # 1) Indices of unmasked (image) pixels — no `size=` to avoid wrong lengths
993- ys , xs = self .slim_to_native_tuple or jnp .nonzero (jnp .logical_not (mask .array ))
994- n_pix , n_src = mapping_matrix .shape
995-
996- # Sanity check
997- if ys .shape [0 ] != n_pix :
998- raise ValueError (
999- f"Mapping rows ({ n_pix } ) != unmasked pixels ({ ys .shape [0 ]} ). "
1000- "Make sure you’re using the image (not blurring) index tuple."
1001- )
1002-
1003- # 2) Allocate native cube (ny, nx, n_src)
1004- mapping_matrix_native = jnp .zeros (
1005- mask .shape + (n_src ,), dtype = mapping_matrix .dtype
1006- )
10071064
1008- # 3) Build index grids with identical shape (n_pix, n_src)
1009- ys_exp = jnp .broadcast_to (ys [:, None ], (n_pix , n_src ))
1010- xs_exp = jnp .broadcast_to (xs [:, None ], (n_pix , n_src ))
1011- src_exp = jnp .broadcast_to (jnp .arange (n_src )[None , :], (n_pix , n_src ))
1012-
1013- # 4) Scatter all at once (values also shape (n_pix, n_src))
1014- mapping_matrix_native = mapping_matrix_native .at [(ys_exp , xs_exp , src_exp )].set (
1015- mapping_matrix
1016- )
1065+ slim_to_native_tuple = self .slim_to_native_tuple
10171066
1018- # 5) Optional blurring mapping matrix
1019- if blurring_mapping_matrix is not None :
1020- ys_b , xs_b = self .slim_to_native_blurring_tuple or jnp .nonzero (
1021- jnp .logical_not (
1022- mask .array
1023- ) # use the correct blurring grid mask here if different
1067+ if slim_to_native_tuple is None :
1068+ slim_to_native_tuple = jnp .nonzero (
1069+ jnp .logical_not (mask .array ), size = mapping_matrix .shape [0 ]
10241070 )
1025- n_blur , n_src_b = blurring_mapping_matrix .shape
1026- if n_src_b != n_src :
1027- raise ValueError (
1028- "blurring_mapping_matrix columns must match mapping_matrix columns (n_src)."
1029- )
1030-
1031- ys_b_exp = jnp .broadcast_to (ys_b [:, None ], (n_blur , n_src ))
1032- xs_b_exp = jnp .broadcast_to (xs_b [:, None ], (n_blur , n_src ))
1033- src_b_exp = jnp .broadcast_to (jnp .arange (n_src )[None , :], (n_blur , n_src ))
1034-
1035- mapping_matrix_native = mapping_matrix_native .at [
1036- (ys_b_exp , xs_b_exp , src_b_exp )
1037- ].set (blurring_mapping_matrix )
10381071
1072+ mapping_matrix_native = self .mapping_matrix_native_from (
1073+ mapping_matrix = mapping_matrix ,
1074+ mask = mask ,
1075+ blurring_mapping_matrix = blurring_mapping_matrix ,
1076+ blurring_mask = blurring_mask ,
1077+ )
10391078 # 6) Real-space convolution, broadcast kernel over source axis
10401079 kernel = self .stored_native .array
1041- convolved_native = jax .scipy .signal .convolve (
1080+ blurred_mapping_matrix_native = jax .scipy .signal .convolve (
10421081 mapping_matrix_native , kernel [..., None ], mode = "same" , method = jax_method
10431082 )
10441083
1045- # 7) Pull back to slim (n_pix, n_src)
1046- blurred_mapping_matrix = convolved_native [ys , xs , :]
1047-
1048- return blurred_mapping_matrix
1084+ # return slim form
1085+ return blurred_mapping_matrix_native [slim_to_native_tuple ]
0 commit comments