Skip to content

Commit e15260d

Browse files
Jammy2211Jammy2211
authored andcommitted
couple of tests
1 parent d4d1994 commit e15260d

File tree

5 files changed

+86
-29
lines changed

5 files changed

+86
-29
lines changed

autoarray/inversion/inversion/imaging/abstract.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
import numpy as np
22
from typing import Dict, List, Union, Type
33

4+
from autoconf import cached_property
5+
46
from autoarray.dataset.imaging.dataset import Imaging
57
from autoarray.inversion.inversion.dataset_interface import DatasetInterface
68
from autoarray.inversion.linear_obj.func_list import AbstractLinearObjFuncList
@@ -136,6 +138,7 @@ def linear_func_operated_mapping_matrix_dict(self) -> Dict:
136138
if linear_func.operated_mapping_matrix_override is not None:
137139
operated_mapping_matrix = linear_func.operated_mapping_matrix_override
138140
else:
141+
vvv
139142
operated_mapping_matrix = self.psf.convolved_mapping_matrix_from(
140143
mapping_matrix=linear_func.mapping_matrix,
141144
mask=self.mask,
@@ -200,7 +203,7 @@ def data_linear_func_matrix_dict(self):
200203

201204
return data_linear_func_matrix_dict
202205

203-
@property
206+
@cached_property
204207
def mapper_operated_mapping_matrix_dict(self) -> Dict:
205208
"""
206209
The `operated_mapping_matrix` of a `Mapper` object describes the mappings between the observed data's values

autoarray/inversion/inversion/imaging/mapping.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -76,7 +76,10 @@ def _data_vector_mapper(self) -> np.ndarray:
7676
param_range = mapper_param_range_list[i]
7777

7878
operated_mapping_matrix = self.psf.convolved_mapping_matrix_from(
79-
mapping_matrix=mapper.mapping_matrix, mask=self.mask, xp=self._xp
79+
mapping_matrix=mapper.mapping_matrix,
80+
mask=self.mask,
81+
use_mixed_precision=self.settings.use_mixed_precision,
82+
xp=self._xp
8083
)
8184

8285
data_vector_mapper = (
@@ -135,7 +138,10 @@ def _curvature_matrix_mapper_diag(self) -> Optional[np.ndarray]:
135138
mapper_param_range_i = mapper_param_range_list[i]
136139

137140
operated_mapping_matrix = self.psf.convolved_mapping_matrix_from(
138-
mapping_matrix=mapper_i.mapping_matrix, mask=self.mask, xp=self._xp
141+
mapping_matrix=mapper_i.mapping_matrix,
142+
mask=self.mask,
143+
use_mixed_precision=self.settings.use_mixed_precision,
144+
xp=self._xp
139145
)
140146

141147
diag = inversion_util.curvature_matrix_via_mapping_matrix_from(

autoarray/inversion/inversion/inversion_util.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -84,6 +84,10 @@ def curvature_matrix_via_mapping_matrix_from(
8484
no_regularization_index_list: Optional[List] = None,
8585
settings: SettingsInversion = SettingsInversion(),
8686
xp=np,
87+
*,
88+
mp_gemm: bool = True, # mixed precision matmul
89+
gemm_dtype=None, # e.g. xp.float32
90+
out_dtype=None, # e.g. xp.float64
8791
) -> np.ndarray:
8892
"""
8993
Returns the curvature matrix `F` from a blurred mapping matrix `f` and the 1D noise-map $\sigma$
@@ -97,8 +101,13 @@ def curvature_matrix_via_mapping_matrix_from(
97101
noise_map
98102
Flattened 1D array of the noise-map used by the inversion during the fit.
99103
"""
100-
array = mapping_matrix / noise_map[:, None]
101-
curvature_matrix = xp.dot(array.T, array)
104+
if gemm_dtype is None:
105+
gemm_dtype = xp.float32 if (mp_gemm and xp is not np) else mapping_matrix.dtype
106+
107+
# form A in chosen dtype (usually float32 on device)
108+
A = (mapping_matrix / noise_map[:, None]).astype(gemm_dtype)
109+
110+
curvature_matrix = xp.dot(A.T, A) # float32 GEMM if A is float32
102111

103112
if add_to_curvature_diag and len(no_regularization_index_list) > 0:
104113
curvature_matrix = curvature_matrix_with_added_to_diag_from(

autoarray/inversion/inversion/settings.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
class SettingsInversion:
1111
def __init__(
1212
self,
13+
use_mixed_precision : bool = False,
1314
use_positive_only_solver: Optional[bool] = None,
1415
positive_only_uses_p_initial: Optional[bool] = None,
1516
use_border_relocator: Optional[bool] = None,
@@ -24,6 +25,12 @@ def __init__(
2425
2526
Parameters
2627
----------
28+
use_mixed_precision
29+
If `True`, the linear algebra calculations of the inversion are performed using single precision on a
30+
targeted subset of functions which provide significant speed up when using a GPU (x4), reduces VRAM
31+
use and are expected to have minimal impact on the accuracy of the results. If `False`, all linear algebra
32+
calculations are performed using double precision, which is the default and is more accurate but
33+
slower on a GPU.
2734
use_positive_only_solver
2835
Whether to use a positive-only linear system solver, which requires that every reconstructed value is
2936
positive but is computationally much slower than the default solver (which allows for positive and

autoarray/structures/arrays/kernel_2d.py

Lines changed: 56 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -533,6 +533,7 @@ def mapping_matrix_native_from(
533533
mask: "Mask2D",
534534
blurring_mapping_matrix: Optional[np.ndarray] = None,
535535
blurring_mask: Optional["Mask2D"] = None,
536+
use_mixed_precision: bool = False,
536537
xp=np,
537538
) -> np.ndarray:
538539
"""
@@ -558,6 +559,10 @@ def mapping_matrix_native_from(
558559
Mask defining the blurring region pixels. Must be provided if
559560
`blurring_mapping_matrix` is given and `slim_to_native_blurring_tuple`
560561
is not already cached.
562+
use_mixed_precision
563+
If True, the mapping matrices are cast to single precision (float32) to
564+
speed up GPU computations and reduce VRAM usage. If False, double precision
565+
(float64) is used for maximum accuracy.
561566
562567
Returns
563568
-------
@@ -566,33 +571,29 @@ def mapping_matrix_native_from(
566571
Contains contributions from both the main mapping matrix and, if provided,
567572
the blurring mapping matrix.
568573
"""
574+
dtype_native = xp.float32 if use_mixed_precision else xp.float64
575+
569576
n_src = mapping_matrix.shape[1]
570577

571-
# Allocate full native grid (ny, nx, n_src)
572-
mapping_matrix_native = xp.zeros(
573-
mask.shape + (n_src,), dtype=mapping_matrix.dtype
574-
)
578+
mapping_matrix_native = xp.zeros(mask.shape + (n_src,), dtype=dtype_native)
579+
580+
# Cast inputs to the target dtype to avoid implicit up/downcasts inside scatter
581+
mm = mapping_matrix if mapping_matrix.dtype == dtype_native else xp.asarray(mapping_matrix, dtype=dtype_native)
575582

576-
# Scatter main mapping matrix into native cube
577583
if xp.__name__.startswith("jax"):
578-
mapping_matrix_native = mapping_matrix_native.at[
579-
mask.slim_to_native_tuple
580-
].set(mapping_matrix)
584+
mapping_matrix_native = mapping_matrix_native.at[mask.slim_to_native_tuple].set(mm)
581585
else:
582-
mapping_matrix_native[mask.slim_to_native_tuple] = mapping_matrix
583-
584-
# Optionally scatter blurring mapping matrix
586+
mapping_matrix_native[mask.slim_to_native_tuple] = np.asarray(mm)
585587

586588
if blurring_mapping_matrix is not None:
589+
bm = blurring_mapping_matrix
590+
if getattr(bm, "dtype", None) != dtype_native:
591+
bm = xp.asarray(bm, dtype=dtype_native)
587592

588593
if xp.__name__.startswith("jax"):
589-
mapping_matrix_native = mapping_matrix_native.at[
590-
blurring_mask.slim_to_native_tuple
591-
].set(blurring_mapping_matrix)
594+
mapping_matrix_native = mapping_matrix_native.at[blurring_mask.slim_to_native_tuple].set(bm)
592595
else:
593-
mapping_matrix_native[blurring_mask.slim_to_native_tuple] = (
594-
blurring_mapping_matrix
595-
)
596+
mapping_matrix_native[blurring_mask.slim_to_native_tuple] = np.asarray(bm)
596597

597598
return mapping_matrix_native
598599

@@ -730,6 +731,7 @@ def convolved_mapping_matrix_from(
730731
blurring_mapping_matrix=None,
731732
blurring_mask: Optional[Mask2D] = None,
732733
jax_method="direct",
734+
use_mixed_precision: bool = False,
733735
xp=np,
734736
):
735737
"""
@@ -770,12 +772,19 @@ def convolved_mapping_matrix_from(
770772
Mapping matrix for the blurring region, outside the mask core.
771773
jax_method : str
772774
Backend passed to real-space convolution if ``use_fft=False``.
775+
use_mixed_precision
776+
If `True`, the FFT is performed using single precision, which provide significant speed up when using a
777+
GPU (x4), reduces VRAM use and is expected to have minimal impact on the accuracy of the results. If `False`,
778+
the FFT is performed using double precision, which is the default and is more accurate but slower on a GPU.
773779
774780
Returns
775781
-------
776782
ndarray of shape (N_pix, N_src)
777783
Convolved mapping matrix in slim form.
778784
"""
785+
# -------------------------------------------------------------------------
786+
# NumPy path unchanged
787+
# -------------------------------------------------------------------------
779788
if xp is np:
780789
return self.convolved_mapping_matrix_via_real_space_np_from(
781790
mapping_matrix=mapping_matrix,
@@ -785,6 +794,9 @@ def convolved_mapping_matrix_from(
785794
xp=xp,
786795
)
787796

797+
# -------------------------------------------------------------------------
798+
# Non-FFT JAX path unchanged
799+
# -------------------------------------------------------------------------
788800
if not self.use_fft:
789801
return self.convolved_mapping_matrix_via_real_space_from(
790802
mapping_matrix=mapping_matrix,
@@ -796,34 +808,50 @@ def convolved_mapping_matrix_from(
796808
)
797809

798810
import jax
811+
import jax.numpy as jnp
799812

813+
# -------------------------------------------------------------------------
814+
# Validate cached FFT shapes / state
815+
# -------------------------------------------------------------------------
800816
if self.fft_shape is None:
801-
802817
full_shape, fft_shape, mask_shape = self.fft_shape_from(mask=mask)
803-
804818
raise ValueError(
805819
f"FFT convolution requires precomputed padded shapes, but `self.fft_shape` is None.\n"
806820
f"Expected mapping matrix padded to match FFT shape of PSF.\n"
807821
f"PSF fft_shape: {fft_shape}, mask shape: {mask.shape}, "
808822
f"mapping_matrix shape: {getattr(mapping_matrix, 'shape', 'unknown')}."
809823
)
810-
811824
else:
812-
813825
fft_shape = self.fft_shape
814826
full_shape = self.full_shape
815827
mask_shape = self.mask_shape
816828
fft_psf_mapping = self.fft_psf_mapping
817829

830+
# -------------------------------------------------------------------------
831+
# Mixed precision dtypes (JAX only)
832+
# -------------------------------------------------------------------------
833+
fft_real_dtype = jnp.float32 if use_mixed_precision else jnp.float64
834+
fft_complex_dtype = jnp.complex64 if use_mixed_precision else jnp.complex128
835+
836+
# Ensure PSF FFT dtype matches the FFT path
837+
fft_psf_mapping = jnp.asarray(fft_psf_mapping, dtype=fft_complex_dtype)
838+
839+
# -------------------------------------------------------------------------
840+
# Build native cube in the FFT dtype (THIS IS THE KEY)
841+
# Requires mapping_matrix_native_from to accept dtype_native kwarg.
842+
# -------------------------------------------------------------------------
818843
mapping_matrix_native = self.mapping_matrix_native_from(
819844
mapping_matrix=mapping_matrix,
820845
mask=mask,
821846
blurring_mapping_matrix=blurring_mapping_matrix,
822847
blurring_mask=blurring_mask,
848+
use_mixed_precision=use_mixed_precision,
823849
xp=xp,
824850
)
825851

852+
# -------------------------------------------------------------------------
826853
# FFT convolution
854+
# -------------------------------------------------------------------------
827855
fft_mapping_matrix_native = xp.fft.rfft2(
828856
mapping_matrix_native, s=fft_shape, axes=(0, 1)
829857
)
@@ -833,7 +861,9 @@ def convolved_mapping_matrix_from(
833861
axes=(0, 1),
834862
)
835863

836-
# crop back
864+
# -------------------------------------------------------------------------
865+
# Crop back to mask-shape
866+
# -------------------------------------------------------------------------
837867
start_indices = tuple(
838868
(full_size - out_size) // 2
839869
for full_size, out_size in zip(full_shape, mask_shape)
@@ -846,8 +876,10 @@ def convolved_mapping_matrix_from(
846876
out_shape_full,
847877
)
848878

849-
# return slim form
850-
return blurred_mapping_matrix_native[mask.slim_to_native_tuple]
879+
# Return slim form
880+
blurred_slim = blurred_mapping_matrix_native[mask.slim_to_native_tuple]
881+
882+
return blurred_slim
851883

852884
def rescaled_with_odd_dimensions_from(
853885
self, rescale_factor: float, normalize: bool = False

0 commit comments

Comments
 (0)