Skip to content

Commit 92809b2

Browse files
Jammy2211
authored and committed
added full use_mixed_precision path
1 parent 411ded5 commit 92809b2

File tree

10 files changed

+132
-64
lines changed

10 files changed

+132
-64
lines changed

autoarray/inversion/inversion/imaging/mapping.py

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -98,6 +98,7 @@ def curvature_matrix(self):
9898
settings=self.settings,
9999
add_to_curvature_diag=True,
100100
no_regularization_index_list=self.no_regularization_index_list,
101+
use_mixed_precision=self.settings.use_mixed_precision,
101102
xp=self._xp,
102103
)
103104

autoarray/inversion/inversion/inversion_util.py

Lines changed: 24 additions & 16 deletions
Original file line number | Diff line number | Diff line change
@@ -78,16 +78,13 @@ def curvature_matrix_mirrored_from(curvature_matrix: np.ndarray, xp=np) -> np.nd
7878

7979

8080
def curvature_matrix_via_mapping_matrix_from(
81-
mapping_matrix: np.ndarray,
82-
noise_map: np.ndarray,
81+
mapping_matrix: "np.ndarray",
82+
noise_map: "np.ndarray",
8383
add_to_curvature_diag: bool = False,
8484
no_regularization_index_list: Optional[List] = None,
85-
settings: SettingsInversion = SettingsInversion(),
85+
settings: "SettingsInversion" = SettingsInversion(),
86+
use_mixed_precision: bool = False,
8687
xp=np,
87-
*,
88-
mp_gemm: bool = True, # mixed precision matmul
89-
gemm_dtype=None, # e.g. xp.float32
90-
out_dtype=None, # e.g. xp.float64
9188
) -> np.ndarray:
9289
"""
9390
Returns the curvature matrix `F` from a blurred mapping matrix `f` and the 1D noise-map $\sigma$
@@ -101,15 +98,26 @@ def curvature_matrix_via_mapping_matrix_from(
10198
noise_map
10299
Flattened 1D array of the noise-map used by the inversion during the fit.
103100
"""
104-
if gemm_dtype is None:
105-
gemm_dtype = xp.float32 if (mp_gemm and xp is not np) else mapping_matrix.dtype
106-
107-
# form A in chosen dtype (usually float32 on device)
108-
A = (mapping_matrix / noise_map[:, None]).astype(gemm_dtype)
109-
110-
curvature_matrix = xp.dot(A.T, A) # float32 GEMM if A is float32
111-
112-
if add_to_curvature_diag and len(no_regularization_index_list) > 0:
101+
# NumPy path: keep it simple + stable
102+
if xp is np:
103+
A = mapping_matrix / noise_map[:, None]
104+
curvature_matrix = xp.dot(A.T, A)
105+
else:
106+
# Choose compute dtype
107+
108+
compute_dtype = xp.float32 if use_mixed_precision else xp.float64
109+
out_dtype = xp.float64 # always return float64 for downstream stability
110+
111+
A = mapping_matrix
112+
w = (1.0 / noise_map).astype(compute_dtype)
113+
A = A * w[:, None]
114+
curvature_matrix = xp.dot(A.T, A).astype(out_dtype)
115+
116+
if (
117+
add_to_curvature_diag
118+
and no_regularization_index_list
119+
and len(no_regularization_index_list) > 0
120+
):
113121
curvature_matrix = curvature_matrix_with_added_to_diag_from(
114122
curvature_matrix=curvature_matrix,
115123
value=settings.no_regularization_add_to_curvature_diag_value,

autoarray/inversion/inversion/settings.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -10,7 +10,7 @@
1010
class SettingsInversion:
1111
def __init__(
1212
self,
13-
use_mixed_precision : bool = False,
13+
use_mixed_precision: bool = False,
1414
use_positive_only_solver: Optional[bool] = None,
1515
positive_only_uses_p_initial: Optional[bool] = None,
1616
use_border_relocator: Optional[bool] = None,

autoarray/inversion/linear_obj/func_list.py

Lines changed: 3 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -7,6 +7,7 @@
77
from autoarray.inversion.linear_obj.neighbors import Neighbors
88
from autoarray.inversion.linear_obj.unique_mappings import UniqueMappings
99
from autoarray.inversion.regularization.abstract import AbstractRegularization
10+
from autoarray.inversion.inversion.settings import SettingsInversion
1011
from autoarray.type import Grid1D2DLike
1112

1213

@@ -15,6 +16,7 @@ def __init__(
1516
self,
1617
grid: Grid1D2DLike,
1718
regularization: Optional[AbstractRegularization],
19+
settings=SettingsInversion(),
1820
xp=np,
1921
):
2022
"""
@@ -45,6 +47,7 @@ def __init__(
4547
super().__init__(regularization=regularization, xp=xp)
4648

4749
self.grid = grid
50+
self.settings = settings
4851

4952
@cached_property
5053
def neighbors(self) -> Neighbors:

autoarray/inversion/pixelization/mappers/abstract.py

Lines changed: 4 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -11,6 +11,7 @@
1111
from autoarray.inversion.pixelization.border_relocator import BorderRelocator
1212
from autoarray.inversion.pixelization.mappers.mapper_grids import MapperGrids
1313
from autoarray.inversion.regularization.abstract import AbstractRegularization
14+
from autoarray.inversion.inversion.settings import SettingsInversion
1415
from autoarray.structures.arrays.uniform_2d import Array2D
1516
from autoarray.structures.grids.uniform_2d import Grid2D
1617
from autoarray.structures.mesh.abstract_2d import Abstract2DMesh
@@ -25,6 +26,7 @@ def __init__(
2526
mapper_grids: MapperGrids,
2627
regularization: Optional[AbstractRegularization],
2728
border_relocator: BorderRelocator,
29+
settings: SettingsInversion = SettingsInversion(),
2830
preloads=None,
2931
xp=np,
3032
):
@@ -90,6 +92,7 @@ def __init__(
9092
self.border_relocator = border_relocator
9193
self.mapper_grids = mapper_grids
9294
self.preloads = preloads
95+
self.settings = settings
9396

9497
@property
9598
def params(self) -> int:
@@ -265,6 +268,7 @@ def mapping_matrix(self) -> np.ndarray:
265268
total_mask_pixels=self.over_sampler.mask.pixels_in_mask,
266269
slim_index_for_sub_slim_index=self.slim_index_for_sub_slim_index,
267270
sub_fraction=self.over_sampler.sub_fraction.array,
271+
use_mixed_precision=self.settings.use_mixed_precision,
268272
xp=self._xp,
269273
)
270274

autoarray/inversion/pixelization/mappers/factory.py

Lines changed: 6 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -4,6 +4,7 @@
44
from autoarray.inversion.pixelization.mappers.mapper_grids import MapperGrids
55
from autoarray.inversion.pixelization.border_relocator import BorderRelocator
66
from autoarray.inversion.regularization.abstract import AbstractRegularization
7+
from autoarray.inversion.inversion.settings import SettingsInversion
78
from autoarray.structures.mesh.rectangular_2d import Mesh2DRectangular
89
from autoarray.structures.mesh.rectangular_2d_uniform import Mesh2DRectangularUniform
910
from autoarray.structures.mesh.delaunay_2d import Mesh2DDelaunay
@@ -13,6 +14,7 @@ def mapper_from(
1314
mapper_grids: MapperGrids,
1415
regularization: Optional[AbstractRegularization],
1516
border_relocator: Optional[BorderRelocator] = None,
17+
settings=SettingsInversion(),
1618
preloads=None,
1719
xp=np,
1820
):
@@ -53,13 +55,17 @@ def mapper_from(
5355
mapper_grids=mapper_grids,
5456
border_relocator=border_relocator,
5557
regularization=regularization,
58+
settings=settings,
59+
preloads=preloads,
5660
xp=xp,
5761
)
5862
elif isinstance(mapper_grids.source_plane_mesh_grid, Mesh2DRectangular):
5963
return MapperRectangular(
6064
mapper_grids=mapper_grids,
6165
border_relocator=border_relocator,
6266
regularization=regularization,
67+
settings=settings,
68+
preloads=preloads,
6369
xp=xp,
6470
)
6571
elif isinstance(mapper_grids.source_plane_mesh_grid, Mesh2DDelaunay):

autoarray/inversion/pixelization/mappers/mapper_util.py

Lines changed: 37 additions & 19 deletions
Original file line number | Diff line number | Diff line change
@@ -548,6 +548,7 @@ def mapping_matrix_from(
548548
total_mask_pixels: int,
549549
slim_index_for_sub_slim_index: np.ndarray,
550550
sub_fraction: np.ndarray,
551+
use_mixed_precision: bool = False,
551552
xp=np,
552553
) -> np.ndarray:
553554
"""
@@ -621,39 +622,56 @@ def mapping_matrix_from(
621622
sub_fraction
622623
The fractional area each sub-pixel takes up in a pixel.
623624
"""
625+
624626
M_sub, B = pix_indexes_for_sub_slim_index.shape
625-
M = total_mask_pixels
626-
S = pixels
627+
M = int(total_mask_pixels)
628+
S = int(pixels)
629+
630+
# Indices always int32
631+
pix_idx = xp.asarray(pix_indexes_for_sub_slim_index, dtype=xp.int32)
632+
pix_size = xp.asarray(pix_size_for_sub_slim_index, dtype=xp.int32)
633+
slim_parent = xp.asarray(slim_index_for_sub_slim_index, dtype=xp.int32)
634+
635+
# Everything else computed in float64
636+
w64 = xp.asarray(pix_weights_for_sub_slim_index, dtype=xp.float64)
637+
frac64 = xp.asarray(sub_fraction, dtype=xp.float64)
638+
639+
# Output dtype only (big allocation)
640+
out_dtype = xp.float32 if use_mixed_precision else xp.float64
627641

628642
# 1) Flatten
629-
flat_pixidx = pix_indexes_for_sub_slim_index.reshape(-1) # (M_sub*B,)
630-
flat_w = pix_weights_for_sub_slim_index.reshape(-1) # (M_sub*B,)
631-
flat_parent = xp.repeat(slim_index_for_sub_slim_index, B) # (M_sub*B,)
632-
flat_count = xp.repeat(pix_size_for_sub_slim_index, B) # (M_sub*B,)
643+
flat_pixidx = pix_idx.reshape(-1) # (M_sub*B,)
644+
flat_w = w64.reshape(-1) # float64
645+
flat_parent = xp.repeat(slim_parent, B) # int32
646+
flat_count = xp.repeat(pix_size, B) # int32
633647

634-
# 2) Build valid mask: k < pix_size[i]
635-
k = xp.tile(xp.arange(B), M_sub) # (M_sub*B,)
636-
valid = k < flat_count # (M_sub*B,)
648+
# 2) valid mask: k < pix_size[i]
649+
k = xp.tile(xp.arange(B, dtype=xp.int32), M_sub)
650+
valid = k < flat_count
637651

638-
# 3) Zero out invalid weights
639-
flat_w = flat_w * valid.astype(flat_w.dtype)
652+
# 3) Zero out invalid weights (float64)
653+
flat_w = flat_w * valid.astype(xp.float64)
640654

641655
# 4) Redirect -1 indices to extra bin S
642656
OUT = S
643657
flat_pixidx = xp.where(flat_pixidx < 0, OUT, flat_pixidx)
644658

645-
# 5) Multiply by sub_fraction of the slim row
646-
flat_frac = xp.take(sub_fraction, flat_parent, axis=0) # (M_sub*B,)
647-
flat_contrib = flat_w * flat_frac # (M_sub*B,)
659+
# 5) Multiply by sub_fraction of the slim row (float64)
660+
flat_frac = xp.take(frac64, flat_parent, axis=0)
661+
flat_contrib64 = flat_w * flat_frac
662+
663+
# 6) Scatter into (M × (S+1)) (destination float32 or float64)
664+
mat = xp.zeros((M, S + 1), dtype=out_dtype)
665+
666+
# Cast only at the write (keeps upstream math float64)
667+
flat_contrib_out = flat_contrib64.astype(out_dtype)
648668

649-
# 6) Scatter into (M × (S+1)), summing duplicates
650-
mat = xp.zeros((M, S + 1), dtype=flat_contrib.dtype)
651669
if xp.__name__.startswith("jax"):
652-
mat = mat.at[flat_parent, flat_pixidx].add(flat_contrib)
670+
mat = mat.at[flat_parent, flat_pixidx].add(flat_contrib_out)
653671
else:
654-
xp.add.at(mat, (flat_parent, flat_pixidx), flat_contrib)
672+
xp.add.at(mat, (flat_parent, flat_pixidx), flat_contrib_out)
655673

656-
# 7) Drop the extra column and return
674+
# 7) Drop extra column
657675
return mat[:, :S]
658676

659677

autoarray/operators/mock/mock_psf.py

Lines changed: 3 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -5,5 +5,7 @@ class MockPSF:
55
def __init__(self, operated_mapping_matrix=None):
66
self.operated_mapping_matrix = operated_mapping_matrix
77

8-
def convolved_mapping_matrix_from(self, mapping_matrix, mask, xp=np):
8+
def convolved_mapping_matrix_from(
9+
self, mapping_matrix, mask, use_mixed_precision=False, xp=np
10+
):
911
return self.operated_mapping_matrix

0 commit comments

Comments
 (0)