Skip to content

Commit 3863f9c

Browse files
Jammy2211Jammy2211
authored andcommitted
no mixed preciison on image convolution
1 parent c428439 commit 3863f9c

File tree

1 file changed

+6
-11
lines changed

1 file changed

+6
-11
lines changed

autoarray/structures/arrays/kernel_2d.py

Lines changed: 6 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -612,7 +612,6 @@ def convolved_image_from(
612612
image,
613613
blurring_image,
614614
jax_method="direct",
615-
use_mixed_precision: bool = False,
616615
xp=np,
617616
):
618617
"""
@@ -670,18 +669,14 @@ def convolved_image_from(
670669
import jax.numpy as jnp
671670
from autoarray.structures.arrays.uniform_2d import Array2D
672671

673-
# FFT path dtypes (JAX only)
674-
fft_real_dtype = jnp.float32 if use_mixed_precision else jnp.float64
675-
fft_complex_dtype = jnp.complex64 if use_mixed_precision else jnp.complex128
676-
677672
if self.fft_shape is None:
678673
# Shapes computed on the fly
679674
full_shape, fft_shape, mask_shape = self.fft_shape_from(mask=image.mask)
680675

681676
# Compute PSF FFT on the fly in the chosen precision
682-
psf_native = jnp.asarray(self.stored_native.array, dtype=fft_real_dtype)
677+
psf_native = jnp.asarray(self.stored_native.array, dtype=jnp.float64)
683678
fft_psf = xp.fft.rfft2(psf_native, s=fft_shape, axes=(0, 1)).astype(
684-
fft_complex_dtype
679+
jnp.complex128
685680
)
686681

687682
image_shape_original = image.shape_native
@@ -701,19 +696,19 @@ def convolved_image_from(
701696
# Use cached PSF FFT but ensure it matches chosen precision.
702697
# IMPORTANT: casting here may create an extra buffer if self.fft_psf is complex128.
703698
# Best practice is to cache a complex64 version on the object when MP is enabled.
704-
fft_psf = jnp.asarray(self.fft_psf, dtype=fft_complex_dtype)
699+
fft_psf = jnp.asarray(self.fft_psf, dtype=jnp.complex128)
705700

706701
# Build combined native image in the FFT dtype
707-
image_both_native = xp.zeros(image.mask.shape, dtype=fft_real_dtype)
702+
image_both_native = xp.zeros(image.mask.shape, dtype=jnp.float64)
708703

709704
image_both_native = image_both_native.at[image.mask.slim_to_native_tuple].set(
710-
jnp.asarray(image.array, dtype=fft_real_dtype)
705+
jnp.asarray(image.array, dtype=jnp.float64)
711706
)
712707

713708
if blurring_image is not None:
714709
image_both_native = image_both_native.at[
715710
blurring_image.mask.slim_to_native_tuple
716-
].set(jnp.asarray(blurring_image.array, dtype=fft_real_dtype))
711+
].set(jnp.asarray(blurring_image.array, dtype=jnp.float64))
717712
else:
718713
warnings.warn(
719714
"No blurring_image provided. Only the direct image will be convolved. "

0 commit comments

Comments
 (0)