Skip to content

Commit 7f921f1

Browse files
authored
Merge pull request #205 from Jammy2211/feature/linalg_mixed_precision
Feature/linalg mixed precision
2 parents be670fa + a78b864 commit 7f921f1

File tree

14 files changed

+190
-156
lines changed

14 files changed

+190
-156
lines changed

autoarray/dataset/imaging/dataset.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -504,6 +504,10 @@ def apply_sparse_operator(
504504
Whether to use JAX to compute W-Tilde. This requires JAX to be installed.
505505
"""
506506

507+
logger.info(
508+
"IMAGING - Setting Up Sparse Operator For low Memory Pixelizations."
509+
)
510+
507511
sparse_operator = (
508512
inversion_imaging_util.ImagingSparseOperator.from_noise_map_and_psf(
509513
data=self.data,

autoarray/dataset/interferometer/dataset.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -198,7 +198,7 @@ def apply_sparse_operator(
198198
if nufft_precision_operator is None:
199199

200200
logger.info(
201-
"INTERFEROMETER - Computing W-Tilde; runtime scales with visibility count and mask resolution, CPU run times may exceed hours."
201+
"INTERFEROMETER - Computing NUFFT Precision Operator; runtime scales with visibility count and mask resolution, CPU run times may exceed hours."
202202
)
203203

204204
nufft_precision_operator = self.psf_precision_operator_from(

autoarray/inversion/inversion/imaging/abstract.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -95,6 +95,7 @@ def operated_mapping_matrix_list(self) -> List[np.ndarray]:
9595
self.psf.convolved_mapping_matrix_from(
9696
mapping_matrix=linear_obj.mapping_matrix,
9797
mask=self.mask,
98+
use_mixed_precision=self.settings.use_mixed_precision,
9899
xp=self._xp,
99100
)
100101
if linear_obj.operated_mapping_matrix_override is None

autoarray/inversion/inversion/imaging/mapping.py

Lines changed: 0 additions & 85 deletions
Original file line numberDiff line numberDiff line change
@@ -53,44 +53,6 @@ def __init__(
5353
xp=xp,
5454
)
5555

56-
@property
57-
def _data_vector_mapper(self) -> np.ndarray:
58-
"""
59-
Returns the `data_vector` of all mappers, a 1D vector whose values are solved for by the simultaneous
60-
linear equations constructed by this object. The object is described in full in the method `data_vector`.
61-
62-
This method is used to compute part of the `data_vector` if there are also linear function list objects
63-
in the inversion, and is separated into a separate method to enable preloading of the mapper `data_vector`.
64-
"""
65-
66-
if not self.has(cls=AbstractMapper):
67-
return
68-
69-
data_vector = np.zeros(self.total_params)
70-
71-
mapper_list = self.cls_list_from(cls=AbstractMapper)
72-
mapper_param_range_list = self.param_range_list_from(cls=AbstractMapper)
73-
74-
for i in range(len(mapper_list)):
75-
mapper = mapper_list[i]
76-
param_range = mapper_param_range_list[i]
77-
78-
operated_mapping_matrix = self.psf.convolved_mapping_matrix_from(
79-
mapping_matrix=mapper.mapping_matrix, mask=self.mask, xp=self._xp
80-
)
81-
82-
data_vector_mapper = (
83-
inversion_imaging_util.data_vector_via_blurred_mapping_matrix_from(
84-
blurred_mapping_matrix=operated_mapping_matrix,
85-
image=self.data,
86-
noise_map=self.noise_map,
87-
)
88-
)
89-
90-
data_vector[param_range[0] : param_range[1],] = data_vector_mapper
91-
92-
return data_vector
93-
9456
@cached_property
9557
def data_vector(self) -> np.ndarray:
9658
"""
@@ -111,53 +73,6 @@ def data_vector(self) -> np.ndarray:
11173
noise_map=self.noise_map.array,
11274
)
11375

114-
@property
115-
def _curvature_matrix_mapper_diag(self) -> Optional[np.ndarray]:
116-
"""
117-
Returns the diagonal regions of the `curvature_matrix`, a 2D matrix which uses the mappings between the data
118-
and the linear objects to construct the simultaneous linear equations. The object is described in full in
119-
the method `curvature_matrix`.
120-
121-
This method computes the diagonal entries of all mapper objects in the `curvature_matrix`. It is separate from
122-
other calculations to enable preloading of this calculation.
123-
"""
124-
125-
if not self.has(cls=AbstractMapper):
126-
return None
127-
128-
curvature_matrix = np.zeros((self.total_params, self.total_params))
129-
130-
mapper_list = self.cls_list_from(cls=AbstractMapper)
131-
mapper_param_range_list = self.param_range_list_from(cls=AbstractMapper)
132-
133-
for i in range(len(mapper_list)):
134-
mapper_i = mapper_list[i]
135-
mapper_param_range_i = mapper_param_range_list[i]
136-
137-
operated_mapping_matrix = self.psf.convolved_mapping_matrix_from(
138-
mapping_matrix=mapper_i.mapping_matrix, mask=self.mask, xp=self._xp
139-
)
140-
141-
diag = inversion_util.curvature_matrix_via_mapping_matrix_from(
142-
mapping_matrix=operated_mapping_matrix,
143-
noise_map=self.noise_map,
144-
settings=self.settings,
145-
add_to_curvature_diag=True,
146-
no_regularization_index_list=self.no_regularization_index_list,
147-
xp=self._xp,
148-
)
149-
150-
curvature_matrix[
151-
mapper_param_range_i[0] : mapper_param_range_i[1],
152-
mapper_param_range_i[0] : mapper_param_range_i[1],
153-
] = diag
154-
155-
curvature_matrix = inversion_util.curvature_matrix_mirrored_from(
156-
curvature_matrix=curvature_matrix, xp=self._xp
157-
)
158-
159-
return curvature_matrix
160-
16176
@cached_property
16277
def curvature_matrix(self):
16378
"""

autoarray/inversion/inversion/interferometer/mapping.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -88,12 +88,14 @@ def curvature_matrix(self) -> np.ndarray:
8888
real_curvature_matrix = inversion_util.curvature_matrix_via_mapping_matrix_from(
8989
mapping_matrix=self.operated_mapping_matrix.real,
9090
noise_map=self.noise_map.real,
91+
settings=self.settings,
9192
xp=self._xp,
9293
)
9394

9495
imag_curvature_matrix = inversion_util.curvature_matrix_via_mapping_matrix_from(
9596
mapping_matrix=self.operated_mapping_matrix.imag,
9697
noise_map=self.noise_map.imag,
98+
settings=self.settings,
9799
xp=self._xp,
98100
)
99101

autoarray/inversion/inversion/inversion_util.py

Lines changed: 23 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -78,11 +78,11 @@ def curvature_matrix_mirrored_from(curvature_matrix: np.ndarray, xp=np) -> np.nd
7878

7979

8080
def curvature_matrix_via_mapping_matrix_from(
81-
mapping_matrix: np.ndarray,
82-
noise_map: np.ndarray,
81+
mapping_matrix: "np.ndarray",
82+
noise_map: "np.ndarray",
8383
add_to_curvature_diag: bool = False,
8484
no_regularization_index_list: Optional[List] = None,
85-
settings: SettingsInversion = SettingsInversion(),
85+
settings: "SettingsInversion" = SettingsInversion(),
8686
xp=np,
8787
) -> np.ndarray:
8888
"""
@@ -97,10 +97,26 @@ def curvature_matrix_via_mapping_matrix_from(
9797
noise_map
9898
Flattened 1D array of the noise-map used by the inversion during the fit.
9999
"""
100-
array = mapping_matrix / noise_map[:, None]
101-
curvature_matrix = xp.dot(array.T, array)
102-
103-
if add_to_curvature_diag and len(no_regularization_index_list) > 0:
100+
# NumPy path: keep it simple + stable
101+
if xp is np:
102+
A = mapping_matrix / noise_map[:, None]
103+
curvature_matrix = xp.dot(A.T, A)
104+
else:
105+
# Choose compute dtype
106+
107+
compute_dtype = xp.float32 if settings.use_mixed_precision else xp.float64
108+
out_dtype = xp.float64 # always return float64 for downstream stability
109+
110+
A = mapping_matrix
111+
w = (1.0 / noise_map).astype(compute_dtype)
112+
A = A * w[:, None]
113+
curvature_matrix = xp.dot(A.T, A).astype(out_dtype)
114+
115+
if (
116+
add_to_curvature_diag
117+
and no_regularization_index_list
118+
and len(no_regularization_index_list) > 0
119+
):
104120
curvature_matrix = curvature_matrix_with_added_to_diag_from(
105121
curvature_matrix=curvature_matrix,
106122
value=settings.no_regularization_add_to_curvature_diag_value,

autoarray/inversion/inversion/settings.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
class SettingsInversion:
1111
def __init__(
1212
self,
13+
use_mixed_precision: bool = False,
1314
use_positive_only_solver: Optional[bool] = None,
1415
positive_only_uses_p_initial: Optional[bool] = None,
1516
use_border_relocator: Optional[bool] = None,
@@ -24,6 +25,12 @@ def __init__(
2425
2526
Parameters
2627
----------
28+
use_mixed_precision
29+
If `True`, the linear algebra calculations of the inversion are performed using single precision on a
30+
targeted subset of functions which provide significant speed up when using a GPU (x4), reduces VRAM
31+
use and are expected to have minimal impact on the accuracy of the results. If `False`, all linear algebra
32+
calculations are performed using double precision, which is the default and is more accurate but
33+
slower on a GPU.
2734
use_positive_only_solver
2835
Whether to use a positive-only linear system solver, which requires that every reconstructed value is
2936
positive but is computationally much slower than the default solver (which allows for positive and
@@ -41,6 +48,7 @@ def __init__(
4148
For an interferometer inversion using the linear operators method, sets the maximum number of iterations
4249
of the solver (this input does nothing for dataset data and other interferometer methods).
4350
"""
51+
self.use_mixed_precision = use_mixed_precision
4452
self._use_positive_only_solver = use_positive_only_solver
4553
self._positive_only_uses_p_initial = positive_only_uses_p_initial
4654
self._use_border_relocator = use_border_relocator

autoarray/inversion/linear_obj/func_list.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
from autoarray.inversion.linear_obj.neighbors import Neighbors
88
from autoarray.inversion.linear_obj.unique_mappings import UniqueMappings
99
from autoarray.inversion.regularization.abstract import AbstractRegularization
10+
from autoarray.inversion.inversion.settings import SettingsInversion
1011
from autoarray.type import Grid1D2DLike
1112

1213

@@ -15,6 +16,7 @@ def __init__(
1516
self,
1617
grid: Grid1D2DLike,
1718
regularization: Optional[AbstractRegularization],
19+
settings=SettingsInversion(),
1820
xp=np,
1921
):
2022
"""
@@ -45,6 +47,7 @@ def __init__(
4547
super().__init__(regularization=regularization, xp=xp)
4648

4749
self.grid = grid
50+
self.settings = settings
4851

4952
@cached_property
5053
def neighbors(self) -> Neighbors:

autoarray/inversion/pixelization/mappers/abstract.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
from autoarray.inversion.pixelization.border_relocator import BorderRelocator
1212
from autoarray.inversion.pixelization.mappers.mapper_grids import MapperGrids
1313
from autoarray.inversion.regularization.abstract import AbstractRegularization
14+
from autoarray.inversion.inversion.settings import SettingsInversion
1415
from autoarray.structures.arrays.uniform_2d import Array2D
1516
from autoarray.structures.grids.uniform_2d import Grid2D
1617
from autoarray.structures.mesh.abstract_2d import Abstract2DMesh
@@ -25,6 +26,7 @@ def __init__(
2526
mapper_grids: MapperGrids,
2627
regularization: Optional[AbstractRegularization],
2728
border_relocator: BorderRelocator,
29+
settings: SettingsInversion = SettingsInversion(),
2830
preloads=None,
2931
xp=np,
3032
):
@@ -90,6 +92,7 @@ def __init__(
9092
self.border_relocator = border_relocator
9193
self.mapper_grids = mapper_grids
9294
self.preloads = preloads
95+
self.settings = settings
9396

9497
@property
9598
def params(self) -> int:
@@ -265,6 +268,7 @@ def mapping_matrix(self) -> np.ndarray:
265268
total_mask_pixels=self.over_sampler.mask.pixels_in_mask,
266269
slim_index_for_sub_slim_index=self.slim_index_for_sub_slim_index,
267270
sub_fraction=self.over_sampler.sub_fraction.array,
271+
use_mixed_precision=self.settings.use_mixed_precision,
268272
xp=self._xp,
269273
)
270274

autoarray/inversion/pixelization/mappers/factory.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
from autoarray.inversion.pixelization.mappers.mapper_grids import MapperGrids
55
from autoarray.inversion.pixelization.border_relocator import BorderRelocator
66
from autoarray.inversion.regularization.abstract import AbstractRegularization
7+
from autoarray.inversion.inversion.settings import SettingsInversion
78
from autoarray.structures.mesh.rectangular_2d import Mesh2DRectangular
89
from autoarray.structures.mesh.rectangular_2d_uniform import Mesh2DRectangularUniform
910
from autoarray.structures.mesh.delaunay_2d import Mesh2DDelaunay
@@ -13,6 +14,7 @@ def mapper_from(
1314
mapper_grids: MapperGrids,
1415
regularization: Optional[AbstractRegularization],
1516
border_relocator: Optional[BorderRelocator] = None,
17+
settings=SettingsInversion(),
1618
preloads=None,
1719
xp=np,
1820
):
@@ -53,20 +55,25 @@ def mapper_from(
5355
mapper_grids=mapper_grids,
5456
border_relocator=border_relocator,
5557
regularization=regularization,
58+
settings=settings,
59+
preloads=preloads,
5660
xp=xp,
5761
)
5862
elif isinstance(mapper_grids.source_plane_mesh_grid, Mesh2DRectangular):
5963
return MapperRectangular(
6064
mapper_grids=mapper_grids,
6165
border_relocator=border_relocator,
6266
regularization=regularization,
67+
settings=settings,
68+
preloads=preloads,
6369
xp=xp,
6470
)
6571
elif isinstance(mapper_grids.source_plane_mesh_grid, Mesh2DDelaunay):
6672
return MapperDelaunay(
6773
mapper_grids=mapper_grids,
6874
border_relocator=border_relocator,
6975
regularization=regularization,
76+
settings=settings,
7077
preloads=preloads,
7178
xp=xp,
7279
)

0 commit comments

Comments
 (0)