@@ -548,6 +548,7 @@ def mapping_matrix_from(
548548 total_mask_pixels : int ,
549549 slim_index_for_sub_slim_index : np .ndarray ,
550550 sub_fraction : np .ndarray ,
551+ use_mixed_precision : bool = False ,
551552 xp = np ,
552553) -> np .ndarray :
553554 """
@@ -621,39 +622,56 @@ def mapping_matrix_from(
621622 sub_fraction
622623 The fractional area each sub-pixel takes up in a pixel.
623624 """
625+
624626 M_sub , B = pix_indexes_for_sub_slim_index .shape
625- M = total_mask_pixels
626- S = pixels
627+ M = int (total_mask_pixels )
628+ S = int (pixels )
629+
630+ # Indices always int32
631+ pix_idx = xp .asarray (pix_indexes_for_sub_slim_index , dtype = xp .int32 )
632+ pix_size = xp .asarray (pix_size_for_sub_slim_index , dtype = xp .int32 )
633+ slim_parent = xp .asarray (slim_index_for_sub_slim_index , dtype = xp .int32 )
634+
635+ # Everything else computed in float64
636+ w64 = xp .asarray (pix_weights_for_sub_slim_index , dtype = xp .float64 )
637+ frac64 = xp .asarray (sub_fraction , dtype = xp .float64 )
638+
639+ # Output dtype only (big allocation)
640+ out_dtype = xp .float32 if use_mixed_precision else xp .float64
627641
628642 # 1) Flatten
629- flat_pixidx = pix_indexes_for_sub_slim_index .reshape (- 1 ) # (M_sub*B,)
630- flat_w = pix_weights_for_sub_slim_index .reshape (- 1 ) # (M_sub*B,)
631- flat_parent = xp .repeat (slim_index_for_sub_slim_index , B ) # (M_sub*B,)
632- flat_count = xp .repeat (pix_size_for_sub_slim_index , B ) # (M_sub*B,)
643+ flat_pixidx = pix_idx .reshape (- 1 ) # (M_sub*B,)
644+ flat_w = w64 .reshape (- 1 ) # float64
645+ flat_parent = xp .repeat (slim_parent , B ) # int32
646+ flat_count = xp .repeat (pix_size , B ) # int32
633647
634- # 2) Build valid mask: k < pix_size[i]
635- k = xp .tile (xp .arange (B ), M_sub ) # ( M_sub*B, )
636- valid = k < flat_count # (M_sub*B,)
648+ # 2) valid mask: k < pix_size[i]
649+ k = xp .tile (xp .arange (B , dtype = xp . int32 ), M_sub )
650+ valid = k < flat_count
637651
638- # 3) Zero out invalid weights
639- flat_w = flat_w * valid .astype (flat_w . dtype )
652+ # 3) Zero out invalid weights (float64)
653+ flat_w = flat_w * valid .astype (xp . float64 )
640654
641655 # 4) Redirect -1 indices to extra bin S
642656 OUT = S
643657 flat_pixidx = xp .where (flat_pixidx < 0 , OUT , flat_pixidx )
644658
645- # 5) Multiply by sub_fraction of the slim row
646- flat_frac = xp .take (sub_fraction , flat_parent , axis = 0 ) # (M_sub*B,)
647- flat_contrib = flat_w * flat_frac # (M_sub*B,)
659+ # 5) Multiply by sub_fraction of the slim row (float64)
660+ flat_frac = xp .take (frac64 , flat_parent , axis = 0 )
661+ flat_contrib64 = flat_w * flat_frac
662+
663+ # 6) Scatter into (M × (S+1)) (destination float32 or float64)
664+ mat = xp .zeros ((M , S + 1 ), dtype = out_dtype )
665+
666+ # Cast only at the write (keeps upstream math float64)
667+ flat_contrib_out = flat_contrib64 .astype (out_dtype )
648668
649- # 6) Scatter into (M × (S+1)), summing duplicates
650- mat = xp .zeros ((M , S + 1 ), dtype = flat_contrib .dtype )
651669 if xp .__name__ .startswith ("jax" ):
652- mat = mat .at [flat_parent , flat_pixidx ].add (flat_contrib )
670+ mat = mat .at [flat_parent , flat_pixidx ].add (flat_contrib_out )
653671 else :
654- xp .add .at (mat , (flat_parent , flat_pixidx ), flat_contrib )
672+ xp .add .at (mat , (flat_parent , flat_pixidx ), flat_contrib_out )
655673
656- # 7) Drop the extra column and return
674+ # 7) Drop extra column
657675 return mat [:, :S ]
658676
659677
0 commit comments