Skip to content

Commit 4055b30

Browse files
Jammy2211Jammy2211
authored andcommitted
mixed precision fix
1 parent 60f08e8 commit 4055b30

File tree

1 file changed

+5
-0
lines changed

1 file changed

+5
-0
lines changed

autoarray/structures/arrays/kernel_2d.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -612,6 +612,7 @@ def convolved_image_from(
612612
image,
613613
blurring_image,
614614
jax_method="direct",
615+
use_mixed_precision : bool = False,
615616
xp=np,
616617
):
617618
"""
@@ -742,6 +743,9 @@ def convolved_image_from(
742743
new_shape=image_shape_original, mask_pad_value=0
743744
)
744745

746+
if use_mixed_precision:
747+
blurred_image = blurred_image.astype(jnp.float32)
748+
745749
return blurred_image
746750

747751
def convolved_mapping_matrix_from(
@@ -1120,6 +1124,7 @@ def convolved_mapping_matrix_via_real_space_from(
11201124
blurring_mask=blurring_mask,
11211125
xp=xp,
11221126
)
1127+
11231128
# 6) Real-space convolution, broadcast kernel over source axis
11241129
kernel = self.stored_native.array
11251130

0 commit comments

Comments
 (0)