Skip to content

Commit 9e8733e

Browse files
Jammy2211Jammy2211
authored andcommitted
store tuple in mask, simplify a lot of code
1 parent a6925d9 commit 9e8733e

File tree

2 files changed

+20
-135
lines changed

2 files changed

+20
-135
lines changed

autoarray/mask/mask_2d.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -217,6 +217,9 @@ def __init__(
217217
xp=xp,
218218
)
219219

220+
slim_to_native = self.derive_indexes.native_for_slim.astype("int32")
221+
self.slim_to_native_tuple = (slim_to_native[:, 0], slim_to_native[:, 1])
222+
220223
@property
221224
def native_for_slim(self):
222225
return self.derive_indexes.native_for_slim

autoarray/structures/arrays/kernel_2d.py

Lines changed: 17 additions & 135 deletions
Original file line numberDiff line numberDiff line change
@@ -122,25 +122,6 @@ def __init__(
122122

123123
self.stored_native = self.native
124124

125-
self.slim_to_native_tuple = None
126-
127-
if image_mask is not None:
128-
129-
slim_to_native = image_mask.derive_indexes.native_for_slim.astype("int32")
130-
self.slim_to_native_tuple = (slim_to_native[:, 0], slim_to_native[:, 1])
131-
132-
self.slim_to_native_blurring_tuple = None
133-
134-
if blurring_mask is not None:
135-
136-
slim_to_native_blurring = (
137-
blurring_mask.derive_indexes.native_for_slim.astype("int32")
138-
)
139-
self.slim_to_native_blurring_tuple = (
140-
slim_to_native_blurring[:, 0],
141-
slim_to_native_blurring[:, 1],
142-
)
143-
144125
self.fft_shape = fft_shape
145126

146127
self.mask_shape = None
@@ -585,18 +566,6 @@ def mapping_matrix_native_from(
585566
Contains contributions from both the main mapping matrix and, if provided,
586567
the blurring mapping matrix.
587568
"""
588-
slim_to_native_tuple = self.slim_to_native_tuple
589-
if slim_to_native_tuple is None:
590-
mask_flat = xp.logical_not(mask.array)
591-
592-
if xp.__name__.startswith("jax"):
593-
slim_to_native_tuple = xp.nonzero(
594-
mask_flat, size=mapping_matrix.shape[0]
595-
)
596-
else:
597-
slim_to_native = mask.derive_indexes.native_for_slim.astype("int32")
598-
slim_to_native_tuple = (slim_to_native[:, 0], slim_to_native[:, 1])
599-
600569
n_src = mapping_matrix.shape[1]
601570

602571
# Allocate full native grid (ny, nx, n_src)
@@ -606,43 +575,22 @@ def mapping_matrix_native_from(
606575

607576
# Scatter main mapping matrix into native cube
608577
if xp.__name__.startswith("jax"):
609-
mapping_matrix_native = mapping_matrix_native.at[slim_to_native_tuple].set(
578+
mapping_matrix_native = mapping_matrix_native.at[mask.slim_to_native_tuple].set(
610579
mapping_matrix
611580
)
612581
else:
613-
mapping_matrix_native[slim_to_native_tuple] = mapping_matrix
582+
mapping_matrix_native[mask.slim_to_native_tuple] = mapping_matrix
583+
614584
# Optionally scatter blurring mapping matrix
585+
615586
if blurring_mapping_matrix is not None:
616-
slim_to_native_blurring_tuple = self.slim_to_native_blurring_tuple
617-
618-
if slim_to_native_blurring_tuple is None:
619-
if blurring_mask is None:
620-
raise ValueError(
621-
"blurring_mask must be provided if blurring_mapping_matrix is given "
622-
"and slim_to_native_blurring_tuple is None."
623-
)
624-
625-
if xp.__name__.startswith("jax"):
626-
mask_flat = xp.logical_not(blurring_mask.array)
627-
slim_to_native_blurring_tuple = xp.nonzero(
628-
mask_flat,
629-
size=blurring_mapping_matrix.shape[0],
630-
)
631-
else:
632-
slim_to_native_blurring = (
633-
blurring_mask.derive_indexes.native_for_slim.astype("int32")
634-
)
635-
slim_to_native_blurring_tuple = (
636-
slim_to_native_blurring[:, 0],
637-
slim_to_native_blurring[:, 1],
638-
)
639587

640588
if xp.__name__.startswith("jax"):
641589
mapping_matrix_native = mapping_matrix_native.at[
642-
slim_to_native_blurring_tuple
590+
blurring_mask.slim_to_native_tuple
643591
].set(blurring_mapping_matrix)
644592
else:
645-
mapping_matrix_native[slim_to_native_blurring_tuple] = (
593+
mapping_matrix_native[blurring_mask.slim_to_native_tuple] = (
646594
blurring_mapping_matrix
647595
)
648596

@@ -722,31 +670,17 @@ def convolved_image_from(self, image, blurring_image, jax_method="direct", xp=np
722670
mask_shape = self.mask_shape
723671
fft_psf = self.fft_psf
724672

725-
slim_to_native_tuple = self.slim_to_native_tuple
726-
slim_to_native_blurring_tuple = self.slim_to_native_blurring_tuple
727-
728-
if slim_to_native_tuple is None:
729-
730-
mask_flat = xp.logical_not(image.mask.array)
731-
slim_to_native_tuple = xp.nonzero(mask_flat, size=image.shape[0])
732-
733673
# start with native image padded with zeros
734674
image_both_native = xp.zeros(image.mask.shape, dtype=image.dtype)
735675

736-
image_both_native = image_both_native.at[slim_to_native_tuple].set(
676+
image_both_native = image_both_native.at[image.mask.slim_to_native_tuple].set(
737677
xp.asarray(image.array)
738678
)
739679

740680
# add blurring contribution if provided
741681
if blurring_image is not None:
742-
if slim_to_native_blurring_tuple is None:
743-
744-
mask_flat = xp.logical_not(blurring_image.mask.array)
745-
slim_to_native_blurring_tuple = xp.nonzero(
746-
mask_flat, size=blurring_image.shape[0]
747-
)
748682

749-
image_both_native = image_both_native.at[slim_to_native_blurring_tuple].set(
683+
image_both_native = image_both_native.at[blurring_image.mask.slim_to_native_tuple].set(
750684
xp.asarray(blurring_image.array)
751685
)
752686

@@ -777,7 +711,7 @@ def convolved_image_from(self, image, blurring_image, jax_method="direct", xp=np
777711
)
778712

779713
blurred_image = Array2D(
780-
values=blurred_image_native[slim_to_native_tuple], mask=image.mask
714+
values=blurred_image_native[image.mask.slim_to_native_tuple], mask=image.mask
781715
)
782716

783717
if self.fft_shape is None:
@@ -880,13 +814,6 @@ def convolved_mapping_matrix_from(
880814
mask_shape = self.mask_shape
881815
fft_psf_mapping = self.fft_psf_mapping
882816

883-
slim_to_native_tuple = self.slim_to_native_tuple
884-
885-
if slim_to_native_tuple is None:
886-
887-
mask_flat = xp.logical_not(mask.array)
888-
slim_to_native_tuple = xp.nonzero(mask_flat, size=mapping_matrix.shape[0])
889-
890817
mapping_matrix_native = self.mapping_matrix_native_from(
891818
mapping_matrix=mapping_matrix,
892819
mask=mask,
@@ -1055,30 +982,18 @@ def convolved_image_via_real_space_from(
1055982

1056983
import jax
1057984

1058-
slim_to_native_tuple = self.slim_to_native_tuple
1059-
slim_to_native_blurring_tuple = self.slim_to_native_blurring_tuple
1060-
1061-
if slim_to_native_tuple is None:
1062-
mask_flat = xp.logical_not(image.mask.array)
1063-
slim_to_native_tuple = xp.nonzero(mask_flat, size=image.shape[0])
1064-
1065985
# start with native array padded with zeros
1066986
image_native = xp.zeros(image.mask.shape, dtype=xp.asarray(image.array).dtype)
1067987

1068988
# set image pixels
1069-
image_native = image_native.at[slim_to_native_tuple].set(
989+
image_native = image_native.at[image.mask.slim_to_native_tuple].set(
1070990
xp.asarray(image.array)
1071991
)
1072992

1073993
# add blurring contribution if provided
1074994
if blurring_image is not None:
1075-
if slim_to_native_blurring_tuple is None:
1076995

1077-
slim_to_native_blurring_tuple = xp.nonzero(
1078-
mask_flat,
1079-
size=blurring_image.shape[0],
1080-
)
1081-
image_native = image_native.at[slim_to_native_blurring_tuple].set(
996+
image_native = image_native.at[blurring_image.mask.slim_to_native_tuple].set(
1082997
xp.asarray(blurring_image.array)
1083998
)
1084999
else:
@@ -1094,7 +1009,7 @@ def convolved_image_via_real_space_from(
10941009
image_native, kernel, mode="same", method=jax_method
10951010
)
10961011

1097-
convolved_array_1d = convolve_native[slim_to_native_tuple]
1012+
convolved_array_1d = convolve_native[image.mask.slim_to_native_tuple]
10981013

10991014
return Array2D(values=convolved_array_1d, mask=image.mask)
11001015

@@ -1146,16 +1061,6 @@ def convolved_mapping_matrix_via_real_space_from(
11461061

11471062
import jax
11481063

1149-
slim_to_native_tuple = self.slim_to_native_tuple
1150-
1151-
if slim_to_native_tuple is None:
1152-
1153-
mask_flat = xp.logical_not(mask.array)
1154-
slim_to_native_tuple = xp.nonzero(
1155-
mask_flat,
1156-
size=mapping_matrix.shape[0],
1157-
)
1158-
11591064
mapping_matrix_native = self.mapping_matrix_native_from(
11601065
mapping_matrix=mapping_matrix,
11611066
mask=mask,
@@ -1174,7 +1079,7 @@ def convolved_mapping_matrix_via_real_space_from(
11741079
)
11751080

11761081
# return slim form
1177-
return blurred_mapping_matrix_native[slim_to_native_tuple]
1082+
return blurred_mapping_matrix_native[mask.slim_to_native_tuple]
11781083

11791084
def convolved_image_via_real_space_np_from(
11801085
self, image: np.ndarray, blurring_image: Optional[np.ndarray] = None, xp=np
@@ -1207,32 +1112,16 @@ def convolved_image_via_real_space_np_from(
12071112

12081113
from scipy.signal import convolve as scipy_convolve
12091114

1210-
slim_to_native_tuple = self.slim_to_native_tuple
1211-
slim_to_native_blurring_tuple = self.slim_to_native_blurring_tuple
1212-
1213-
if slim_to_native_tuple is None:
1214-
slim_to_native = image.mask.derive_indexes.native_for_slim.astype("int32")
1215-
slim_to_native_tuple = (slim_to_native[:, 0], slim_to_native[:, 1])
1216-
12171115
# start with native array padded with zeros
12181116
image_native = xp.zeros(image.mask.shape, dtype=xp.asarray(image.array).dtype)
12191117

12201118
# set image pixels
1221-
image_native[slim_to_native_tuple] = xp.asarray(image.array)
1119+
image_native[image.mask.slim_to_native_tuple] = xp.asarray(image.array)
12221120

12231121
# add blurring contribution if provided
12241122
if blurring_image is not None:
1225-
if slim_to_native_blurring_tuple is None:
1226-
1227-
slim_to_native_blurring = (
1228-
blurring_image.mask.derive_indexes.native_for_slim.astype("int32")
1229-
)
1230-
slim_to_native_blurring_tuple = (
1231-
slim_to_native_blurring[:, 0],
1232-
slim_to_native_blurring[:, 1],
1233-
)
12341123

1235-
image_native[slim_to_native_blurring_tuple] = xp.asarray(
1124+
image_native[blurring_image.mask.slim_to_native_tuple] = xp.asarray(
12361125
blurring_image.array
12371126
)
12381127
else:
@@ -1248,7 +1137,7 @@ def convolved_image_via_real_space_np_from(
12481137
image_native, kernel, mode="same", method="auto"
12491138
)
12501139

1251-
convolved_array_1d = convolve_native[slim_to_native_tuple]
1140+
convolved_array_1d = convolve_native[image.mask.slim_to_native_tuple]
12521141

12531142
return Array2D(values=convolved_array_1d, mask=image.mask)
12541143

@@ -1290,13 +1179,6 @@ def convolved_mapping_matrix_via_real_space_np_from(
12901179

12911180
from scipy.signal import convolve as scipy_convolve
12921181

1293-
slim_to_native_tuple = self.slim_to_native_tuple
1294-
1295-
if slim_to_native_tuple is None:
1296-
1297-
slim_to_native = mask.derive_indexes.native_for_slim.astype("int32")
1298-
slim_to_native_tuple = (slim_to_native[:, 0], slim_to_native[:, 1])
1299-
13001182
mapping_matrix_native = self.mapping_matrix_native_from(
13011183
mapping_matrix=mapping_matrix,
13021184
mask=mask,
@@ -1314,4 +1196,4 @@ def convolved_mapping_matrix_via_real_space_np_from(
13141196
)
13151197

13161198
# return slim form
1317-
return blurred_mapping_matrix_native[slim_to_native_tuple]
1199+
return blurred_mapping_matrix_native[mask.slim_to_native_tuple]

0 commit comments

Comments
 (0)