@@ -533,6 +533,7 @@ def mapping_matrix_native_from(
533533 mask : "Mask2D" ,
534534 blurring_mapping_matrix : Optional [np .ndarray ] = None ,
535535 blurring_mask : Optional ["Mask2D" ] = None ,
536+ use_mixed_precision : bool = False ,
536537 xp = np ,
537538 ) -> np .ndarray :
538539 """
@@ -558,6 +559,10 @@ def mapping_matrix_native_from(
558559 Mask defining the blurring region pixels. Must be provided if
559560 `blurring_mapping_matrix` is given and `slim_to_native_blurring_tuple`
560561 is not already cached.
562+ use_mixed_precision
563+ If True, the mapping matrices are cast to single precision (float32) to
564+ speed up GPU computations and reduce VRAM usage. If False, double precision
565+ (float64) is used for maximum accuracy.
561566
562567 Returns
563568 -------
@@ -566,33 +571,29 @@ def mapping_matrix_native_from(
566571 Contains contributions from both the main mapping matrix and, if provided,
567572 the blurring mapping matrix.
568573 """
574+ dtype_native = xp .float32 if use_mixed_precision else xp .float64
575+
569576 n_src = mapping_matrix .shape [1 ]
570577
571- # Allocate full native grid (ny, nx, n_src )
572- mapping_matrix_native = xp . zeros (
573- mask . shape + ( n_src ,), dtype = mapping_matrix . dtype
574- )
578+ mapping_matrix_native = xp . zeros ( mask . shape + ( n_src ,), dtype = dtype_native )
579+
580+ # Cast inputs to the target dtype to avoid implicit up/downcasts inside scatter
581+ mm = mapping_matrix if mapping_matrix . dtype == dtype_native else xp . asarray ( mapping_matrix , dtype = dtype_native )
575582
576- # Scatter main mapping matrix into native cube
577583 if xp .__name__ .startswith ("jax" ):
578- mapping_matrix_native = mapping_matrix_native .at [
579- mask .slim_to_native_tuple
580- ].set (mapping_matrix )
584+ mapping_matrix_native = mapping_matrix_native .at [mask .slim_to_native_tuple ].set (mm )
581585 else :
582- mapping_matrix_native [mask .slim_to_native_tuple ] = mapping_matrix
583-
584- # Optionally scatter blurring mapping matrix
586+ mapping_matrix_native [mask .slim_to_native_tuple ] = np .asarray (mm )
585587
586588 if blurring_mapping_matrix is not None :
589+ bm = blurring_mapping_matrix
590+ if getattr (bm , "dtype" , None ) != dtype_native :
591+ bm = xp .asarray (bm , dtype = dtype_native )
587592
588593 if xp .__name__ .startswith ("jax" ):
589- mapping_matrix_native = mapping_matrix_native .at [
590- blurring_mask .slim_to_native_tuple
591- ].set (blurring_mapping_matrix )
594+ mapping_matrix_native = mapping_matrix_native .at [blurring_mask .slim_to_native_tuple ].set (bm )
592595 else :
593- mapping_matrix_native [blurring_mask .slim_to_native_tuple ] = (
594- blurring_mapping_matrix
595- )
596+ mapping_matrix_native [blurring_mask .slim_to_native_tuple ] = np .asarray (bm )
596597
597598 return mapping_matrix_native
598599
@@ -730,6 +731,7 @@ def convolved_mapping_matrix_from(
730731 blurring_mapping_matrix = None ,
731732 blurring_mask : Optional [Mask2D ] = None ,
732733 jax_method = "direct" ,
734+ use_mixed_precision : bool = False ,
733735 xp = np ,
734736 ):
735737 """
@@ -770,12 +772,19 @@ def convolved_mapping_matrix_from(
770772 Mapping matrix for the blurring region, outside the mask core.
771773 jax_method : str
772774 Backend passed to real-space convolution if ``use_fft=False``.
775+ use_mixed_precision
776+ If `True`, the FFT is performed using single precision, which provide significant speed up when using a
777+ GPU (x4), reduces VRAM use and is expected to have minimal impact on the accuracy of the results. If `False`,
778+ the FFT is performed using double precision, which is the default and is more accurate but slower on a GPU.
773779
774780 Returns
775781 -------
776782 ndarray of shape (N_pix, N_src)
777783 Convolved mapping matrix in slim form.
778784 """
785+ # -------------------------------------------------------------------------
786+ # NumPy path unchanged
787+ # -------------------------------------------------------------------------
779788 if xp is np :
780789 return self .convolved_mapping_matrix_via_real_space_np_from (
781790 mapping_matrix = mapping_matrix ,
@@ -785,6 +794,9 @@ def convolved_mapping_matrix_from(
785794 xp = xp ,
786795 )
787796
797+ # -------------------------------------------------------------------------
798+ # Non-FFT JAX path unchanged
799+ # -------------------------------------------------------------------------
788800 if not self .use_fft :
789801 return self .convolved_mapping_matrix_via_real_space_from (
790802 mapping_matrix = mapping_matrix ,
@@ -796,34 +808,50 @@ def convolved_mapping_matrix_from(
796808 )
797809
798810 import jax
811+ import jax .numpy as jnp
799812
813+ # -------------------------------------------------------------------------
814+ # Validate cached FFT shapes / state
815+ # -------------------------------------------------------------------------
800816 if self .fft_shape is None :
801-
802817 full_shape , fft_shape , mask_shape = self .fft_shape_from (mask = mask )
803-
804818 raise ValueError (
805819 f"FFT convolution requires precomputed padded shapes, but `self.fft_shape` is None.\n "
806820 f"Expected mapping matrix padded to match FFT shape of PSF.\n "
807821 f"PSF fft_shape: { fft_shape } , mask shape: { mask .shape } , "
808822 f"mapping_matrix shape: { getattr (mapping_matrix , 'shape' , 'unknown' )} ."
809823 )
810-
811824 else :
812-
813825 fft_shape = self .fft_shape
814826 full_shape = self .full_shape
815827 mask_shape = self .mask_shape
816828 fft_psf_mapping = self .fft_psf_mapping
817829
830+ # -------------------------------------------------------------------------
831+ # Mixed precision dtypes (JAX only)
832+ # -------------------------------------------------------------------------
833+ fft_real_dtype = jnp .float32 if use_mixed_precision else jnp .float64
834+ fft_complex_dtype = jnp .complex64 if use_mixed_precision else jnp .complex128
835+
836+ # Ensure PSF FFT dtype matches the FFT path
837+ fft_psf_mapping = jnp .asarray (fft_psf_mapping , dtype = fft_complex_dtype )
838+
839+ # -------------------------------------------------------------------------
840+ # Build native cube in the FFT dtype (THIS IS THE KEY)
841+ # Requires mapping_matrix_native_from to accept dtype_native kwarg.
842+ # -------------------------------------------------------------------------
818843 mapping_matrix_native = self .mapping_matrix_native_from (
819844 mapping_matrix = mapping_matrix ,
820845 mask = mask ,
821846 blurring_mapping_matrix = blurring_mapping_matrix ,
822847 blurring_mask = blurring_mask ,
848+ use_mixed_precision = use_mixed_precision ,
823849 xp = xp ,
824850 )
825851
852+ # -------------------------------------------------------------------------
826853 # FFT convolution
854+ # -------------------------------------------------------------------------
827855 fft_mapping_matrix_native = xp .fft .rfft2 (
828856 mapping_matrix_native , s = fft_shape , axes = (0 , 1 )
829857 )
@@ -833,7 +861,9 @@ def convolved_mapping_matrix_from(
833861 axes = (0 , 1 ),
834862 )
835863
836- # crop back
864+ # -------------------------------------------------------------------------
865+ # Crop back to mask-shape
866+ # -------------------------------------------------------------------------
837867 start_indices = tuple (
838868 (full_size - out_size ) // 2
839869 for full_size , out_size in zip (full_shape , mask_shape )
@@ -846,8 +876,10 @@ def convolved_mapping_matrix_from(
846876 out_shape_full ,
847877 )
848878
849- # return slim form
850- return blurred_mapping_matrix_native [mask .slim_to_native_tuple ]
879+ # Return slim form
880+ blurred_slim = blurred_mapping_matrix_native [mask .slim_to_native_tuple ]
881+
882+ return blurred_slim
851883
852884 def rescaled_with_odd_dimensions_from (
853885 self , rescale_factor : float , normalize : bool = False
0 commit comments