Skip to content

Commit e46b087

Browse files
Jammy2211Jammy2211
authored andcommitted
refactor kernel mapping for speed
1 parent ad33991 commit e46b087

File tree

6 files changed

+180
-42
lines changed

6 files changed

+180
-42
lines changed

autoarray/dataset/imaging/dataset.py

Lines changed: 19 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -159,9 +159,26 @@ def __init__(
159159
"""
160160
)
161161

162-
if psf is not None and use_normalized_psf:
162+
if psf is not None:
163+
164+
if not data.mask.is_all_false:
165+
166+
image_mask = data.mask
167+
blurring_mask = data.mask.derive_mask.blurring_from(
168+
kernel_shape_native=psf.shape_native
169+
)
170+
171+
else:
172+
173+
image_mask = None
174+
blurring_mask = None
175+
163176
psf = Kernel2D.no_mask(
164-
values=psf.native._array, pixel_scales=psf.pixel_scales, normalize=True
177+
values=psf.native._array,
178+
pixel_scales=psf.pixel_scales,
179+
normalize=use_normalized_psf,
180+
image_mask=image_mask,
181+
blurring_mask=blurring_mask,
165182
)
166183

167184
self.psf = psf

autoarray/inversion/pixelization/mappers/rectangular.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -97,6 +97,24 @@ def pix_sub_weights(self) -> PixSubWeights:
9797
dimension of the array `pix_indexes_for_sub_slim_index` 1 and all entries in `pix_weights_for_sub_slim_index`
9898
are equal to 1.0.
9999
"""
100+
# from autoarray.geometry import geometry_util
101+
#
102+
# mappings = geometry_util.grid_pixel_indexes_2d_slim_from(
103+
# grid_scaled_2d_slim=np.array(self.source_plane_data_grid.over_sampled),
104+
# shape_native=self.source_plane_mesh_grid.shape_native,
105+
# pixel_scales=self.source_plane_mesh_grid.pixel_scales,
106+
# origin=self.source_plane_mesh_grid.origin,
107+
# ).astype("int")
108+
#
109+
# mappings = mappings.reshape((len(mappings), 1))
110+
#
111+
# return PixSubWeights(
112+
# mappings=mappings,
113+
# sizes=np.ones(len(mappings), dtype="int"),
114+
# weights=np.ones(
115+
# (len(self.source_plane_data_grid.over_sampled), 1), dtype="int"
116+
# ),
117+
# )
100118

101119
mappings, weights = (
102120
mapper_util.rectangular_mappings_weights_via_interpolation_from(

autoarray/mask/mask_2d_util.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -50,9 +50,14 @@ def native_index_for_slim_index_2d_from(
5050
5151
native_index_for_slim_index_2d = native_index_for_slim_index_2d_from(mask_2d=mask_2d)
5252
"""
53-
if isinstance(mask_2d, np.ndarray):
54-
return np.stack(np.nonzero(~mask_2d.astype(bool))).T
55-
return jnp.stack(jnp.nonzero(~mask_2d.astype(bool))).T
53+
54+
if isinstance(mask_2d, jnp.ndarray):
55+
# JAX branch (assume jnp.ndarray)
56+
rows, cols = jnp.where(~mask_2d.astype(bool))
57+
return jnp.stack([rows, cols], axis=1)
58+
59+
rows, cols = np.where(~mask_2d.astype(bool))
60+
return np.stack([rows, cols], axis=1)
5661

5762

5863
def mask_2d_centres_from(

autoarray/plot/wrap/base/output.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -108,7 +108,7 @@ def savefig(self, filename: str, output_path: str, format: str):
108108
plt.savefig(
109109
path.join(output_path, f"{filename}.{format}"),
110110
bbox_inches=self.bbox_inches,
111-
pad_inches=0,
111+
pad_inches=0.1,
112112
)
113113
except ValueError as e:
114114
logger.info(

autoarray/structures/arrays/kernel_2d.py

Lines changed: 115 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,8 @@ def __init__(
2424
header=None,
2525
normalize: bool = False,
2626
store_native: bool = False,
27+
image_mask=None,
28+
blurring_mask=None,
2729
*args,
2830
**kwargs,
2931
):
@@ -56,6 +58,25 @@ def __init__(
5658

5759
self.stored_native = self.native
5860

61+
self.slim_to_native_tuple = None
62+
63+
if image_mask is not None:
64+
65+
slim_to_native = image_mask.derive_indexes.native_for_slim.astype("int32")
66+
self.slim_to_native_tuple = (slim_to_native[:, 0], slim_to_native[:, 1])
67+
68+
self.slim_to_native_blurring_tuple = None
69+
70+
if blurring_mask is not None:
71+
72+
slim_to_native_blurring = (
73+
blurring_mask.derive_indexes.native_for_slim.astype("int32")
74+
)
75+
self.slim_to_native_blurring_tuple = (
76+
slim_to_native_blurring[:, 0],
77+
slim_to_native_blurring[:, 1],
78+
)
79+
5980
@classmethod
6081
def no_mask(
6182
cls,
@@ -64,6 +85,8 @@ def no_mask(
6485
shape_native: Tuple[int, int] = None,
6586
origin: Tuple[float, float] = (0.0, 0.0),
6687
normalize: bool = False,
88+
image_mask=None,
89+
blurring_mask=None,
6790
):
6891
"""
6992
Create a Kernel2D (see *Kernel2D.__new__*) by inputting the kernel values in 1D or 2D, automatically
@@ -91,7 +114,13 @@ def no_mask(
91114
pixel_scales=pixel_scales,
92115
origin=origin,
93116
)
94-
return Kernel2D(values=values, mask=values.mask, normalize=normalize)
117+
return Kernel2D(
118+
values=values,
119+
mask=values.mask,
120+
normalize=normalize,
121+
image_mask=image_mask,
122+
blurring_mask=blurring_mask,
123+
)
95124

96125
@classmethod
97126
def full(
@@ -540,29 +569,41 @@ def convolve_image(self, image, blurring_image, jax_method="direct"):
540569
kernels that are more than about 5x5. Default is `fft`.
541570
"""
542571

543-
slim_to_native = jnp.nonzero(
544-
jnp.logical_not(image.mask.array), size=image.shape[0]
545-
)
546-
slim_to_native_blurring = jnp.nonzero(
547-
jnp.logical_not(blurring_image.mask.array), size=blurring_image.shape[0]
548-
)
572+
slim_to_native_tuple = self.slim_to_native_tuple
573+
slim_to_native_blurring_tuple = self.slim_to_native_blurring_tuple
549574

550-
expanded_array_native = jnp.zeros(image.mask.shape)
575+
if slim_to_native_tuple is None:
551576

552-
expanded_array_native = expanded_array_native.at[slim_to_native].set(
553-
image.array
577+
slim_to_native_tuple = jnp.nonzero(
578+
jnp.logical_not(image.mask.array), size=image.shape[0]
579+
)
580+
581+
if slim_to_native_blurring_tuple is None:
582+
583+
slim_to_native_blurring_tuple = jnp.nonzero(
584+
jnp.logical_not(blurring_image.mask.array), size=blurring_image.shape[0]
585+
)
586+
587+
# make sure dtype matches what you want
588+
expanded_array_native = jnp.zeros(
589+
image.mask.shape, dtype=jnp.asarray(image.array).dtype
554590
)
555-
expanded_array_native = expanded_array_native.at[slim_to_native_blurring].set(
556-
blurring_image.array
591+
592+
# set using a tuple of index arrays
593+
expanded_array_native = expanded_array_native.at[slim_to_native_tuple].set(
594+
jnp.asarray(image.array)
557595
)
596+
expanded_array_native = expanded_array_native.at[
597+
slim_to_native_blurring_tuple
598+
].set(jnp.asarray(blurring_image.array))
558599

559600
kernel = self.stored_native.array
560601

561602
convolve_native = jax.scipy.signal.convolve(
562603
expanded_array_native, kernel, mode="same", method=jax_method
563604
)
564605

565-
convolved_array_1d = convolve_native[slim_to_native]
606+
convolved_array_1d = convolve_native[slim_to_native_tuple]
566607

567608
return Array2D(values=convolved_array_1d, mask=image.mask)
568609

@@ -583,24 +624,77 @@ def convolve_image_no_blurring(self, image, mask, jax_method="direct"):
583624
kernels that are more than about 5x5. Default is `fft`.
584625
"""
585626

586-
slim_to_native = jnp.nonzero(jnp.logical_not(mask.array), size=image.shape[0])
627+
slim_to_native_tuple = self.slim_to_native_tuple
628+
629+
if slim_to_native_tuple is None:
630+
631+
slim_to_native_tuple = jnp.nonzero(
632+
jnp.logical_not(mask.array), size=image.shape[0]
633+
)
587634

635+
# make sure dtype matches what you want
588636
expanded_array_native = jnp.zeros(mask.shape)
589637

638+
# set using a tuple of index arrays
590639
if isinstance(image, np.ndarray) or isinstance(image, jnp.ndarray):
591-
expanded_array_native = expanded_array_native.at[slim_to_native].set(image)
640+
expanded_array_native = expanded_array_native.at[slim_to_native_tuple].set(
641+
image
642+
)
592643
else:
593-
expanded_array_native = expanded_array_native.at[slim_to_native].set(
594-
image.array
644+
expanded_array_native = expanded_array_native.at[slim_to_native_tuple].set(
645+
jnp.asarray(image.array)
646+
)
647+
648+
kernel = self.stored_native.array
649+
650+
convolve_native = jax.scipy.signal.convolve(
651+
expanded_array_native, kernel, mode="same", method=jax_method
652+
)
653+
654+
convolved_array_1d = convolve_native[slim_to_native_tuple]
655+
656+
return Array2D(values=convolved_array_1d, mask=mask)
657+
658+
def convolve_image_no_blurring_for_mapping(self, image, mask, jax_method="direct"):
659+
"""
660+
For a given 1D array and blurring array, convolve the two using this psf.
661+
662+
Parameters
663+
----------
664+
image
665+
1D array of the values which are to be blurred with the psf's PSF.
666+
blurring_image
667+
1D array of the blurring values which blur into the array after PSF convolution.
668+
jax_method
669+
If JAX is enabled this keyword will indicate what method is used for the PSF
670+
convolution. Can be either `direct` to calculate it in real space or `fft`
671+
to calculated it via a fast Fourier transform. `fft` is typically faster for
672+
kernels that are more than about 5x5. Default is `fft`.
673+
"""
674+
675+
slim_to_native_tuple = self.slim_to_native_tuple
676+
677+
if slim_to_native_tuple is None:
678+
679+
slim_to_native_tuple = jnp.nonzero(
680+
jnp.logical_not(mask.array), size=image.shape[0]
595681
)
596682

683+
# make sure dtype matches what you want
684+
expanded_array_native = jnp.zeros(mask.shape)
685+
686+
# set using a tuple of index arrays
687+
expanded_array_native = expanded_array_native.at[slim_to_native_tuple].set(
688+
image
689+
)
690+
597691
kernel = self.stored_native.array
598692

599693
convolve_native = jax.scipy.signal.convolve(
600694
expanded_array_native, kernel, mode="same", method=jax_method
601695
)
602696

603-
convolved_array_1d = convolve_native[slim_to_native]
697+
convolved_array_1d = convolve_native[slim_to_native_tuple]
604698

605699
return Array2D(values=convolved_array_1d, mask=mask)
606700

@@ -612,6 +706,6 @@ def convolve_mapping_matrix(self, mapping_matrix, mask, jax_method="direct"):
612706
image
613707
1D array of the values which are to be blurred with the psf's PSF.
614708
"""
615-
return jax.vmap(self.convolve_image_no_blurring, in_axes=(1, None, None))(
616-
mapping_matrix, mask, jax_method
617-
).T
709+
return jax.vmap(
710+
self.convolve_image_no_blurring_for_mapping, in_axes=(1, None, None)
711+
)(mapping_matrix, mask, jax_method).T

autoarray/structures/grids/grid_2d_util.py

Lines changed: 19 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -253,25 +253,29 @@ def grid_2d_slim_via_mask_from(
253253
centres_scaled = geometry_util.central_scaled_coordinate_2d_from(
254254
shape_native=mask_2d.shape, pixel_scales=pixel_scales, origin=origin
255255
)
256-
if isinstance(mask_2d, jnp.ndarray):
257256

258-
centres_scaled = jnp.array(centres_scaled)
259-
pixel_scales = jnp.array(pixel_scales)
257+
# JAX branch
258+
if isinstance(mask_2d, jnp.ndarray):
259+
centres_scaled = jnp.asarray(centres_scaled)
260+
pixel_scales = jnp.asarray(pixel_scales)
260261
sign = jnp.array([-1.0, 1.0])
261-
return (
262-
(jnp.stack(jnp.nonzero(~mask_2d.astype(bool))).T - centres_scaled)
263-
* sign
264-
* pixel_scales
265-
)
266262

267-
centres_scaled = np.array(centres_scaled)
268-
pixel_scales = np.array(pixel_scales)
263+
# use jnp.where instead of jnp.nonzero
264+
rows, cols = jnp.where(~mask_2d.astype(bool))
265+
indices = jnp.stack([rows, cols], axis=1) # shape (N_unmasked, 2)
266+
267+
# (indices - centre) -> pixel offsets; apply sign and scale to get physical coords
268+
return (indices - centres_scaled) * sign * pixel_scales
269+
270+
# NumPy branch (kept consistent)
271+
centres_scaled = np.asarray(centres_scaled)
272+
pixel_scales = np.asarray(pixel_scales)
269273
sign = np.array([-1.0, 1.0])
270-
return (
271-
(np.stack(np.nonzero(~mask_2d.astype(bool))).T - centres_scaled)
272-
* sign
273-
* pixel_scales
274-
)
274+
275+
rows, cols = np.where(~mask_2d.astype(bool))
276+
indices = np.stack([rows, cols], axis=1)
277+
278+
return (indices - centres_scaled) * sign * pixel_scales
275279

276280

277281
def grid_2d_via_mask_from(

0 commit comments

Comments
 (0)