@@ -612,7 +612,6 @@ def convolved_image_from(
612612 image ,
613613 blurring_image ,
614614 jax_method = "direct" ,
615- use_mixed_precision : bool = False ,
616615 xp = np ,
617616 ):
618617 """
@@ -670,18 +669,14 @@ def convolved_image_from(
670669 import jax .numpy as jnp
671670 from autoarray .structures .arrays .uniform_2d import Array2D
672671
673- # FFT path dtypes (JAX only)
674- fft_real_dtype = jnp .float32 if use_mixed_precision else jnp .float64
675- fft_complex_dtype = jnp .complex64 if use_mixed_precision else jnp .complex128
676-
677672 if self .fft_shape is None :
678673 # Shapes computed on the fly
679674 full_shape , fft_shape , mask_shape = self .fft_shape_from (mask = image .mask )
680675
681676 # Compute PSF FFT on the fly in the chosen precision
682- psf_native = jnp .asarray (self .stored_native .array , dtype = fft_real_dtype )
677+ psf_native = jnp .asarray (self .stored_native .array , dtype = jnp . float64 )
683678 fft_psf = xp .fft .rfft2 (psf_native , s = fft_shape , axes = (0 , 1 )).astype (
684- fft_complex_dtype
679+ jnp . complex128
685680 )
686681
687682 image_shape_original = image .shape_native
@@ -701,19 +696,19 @@ def convolved_image_from(
701696 # Use cached PSF FFT but ensure it matches chosen precision.
702697 # IMPORTANT: casting here may create an extra buffer if self.fft_psf is complex128.
703698 # Best practice is to cache a complex64 version on the object when MP is enabled.
704- fft_psf = jnp .asarray (self .fft_psf , dtype = fft_complex_dtype )
699+ fft_psf = jnp .asarray (self .fft_psf , dtype = jnp . complex128 )
705700
706701 # Build combined native image in the FFT dtype
707- image_both_native = xp .zeros (image .mask .shape , dtype = fft_real_dtype )
702+ image_both_native = xp .zeros (image .mask .shape , dtype = jnp . float64 )
708703
709704 image_both_native = image_both_native .at [image .mask .slim_to_native_tuple ].set (
710- jnp .asarray (image .array , dtype = fft_real_dtype )
705+ jnp .asarray (image .array , dtype = jnp . float64 )
711706 )
712707
713708 if blurring_image is not None :
714709 image_both_native = image_both_native .at [
715710 blurring_image .mask .slim_to_native_tuple
716- ].set (jnp .asarray (blurring_image .array , dtype = fft_real_dtype ))
711+ ].set (jnp .asarray (blurring_image .array , dtype = jnp . float64 ))
717712 else :
718713 warnings .warn (
719714 "No blurring_image provided. Only the direct image will be convolved. "
0 commit comments