Skip to content

Commit c161cda

Browse files
authored
Merge pull request #175 from Jammy2211/feature/jax_linear_light
Feature/jax linear light
2 parents c3e4ce3 + 8051e35 commit c161cda

File tree

12 files changed

+152
-230
lines changed

12 files changed

+152
-230
lines changed

autoarray/inversion/inversion/abstract.py

Lines changed: 29 additions & 110 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import copy
2+
import jax
23
import jax.numpy as jnp
34
import numpy as np
45
from scipy.linalg import block_diag
@@ -73,17 +74,6 @@ def __init__(
7374
A dictionary which contains timing of certain functions calls which is used for profiling.
7475
"""
7576

76-
# try:
77-
# import numba
78-
# except ModuleNotFoundError:
79-
# raise exc.InversionException(
80-
# "Inversion functionality (linear light profiles, pixelized reconstructions) is "
81-
# "disabled if numba is not installed.\n\n"
82-
# "This is because the run-times without numba are too slow.\n\n"
83-
# "Please install numba, which is described at the following web page:\n\n"
84-
# "https://pyautolens.readthedocs.io/en/latest/installation/overview.html"
85-
# )
86-
8777
self.dataset = dataset
8878

8979
self.linear_obj_list = linear_obj_list
@@ -317,7 +307,7 @@ def operated_mapping_matrix(self) -> np.ndarray:
317307
If there are multiple linear objects, the blurred mapping matrices are stacked such that their simultaneous
318308
linear equations are solved simultaneously.
319309
"""
320-
return np.hstack(self.operated_mapping_matrix_list)
310+
return jnp.hstack(self.operated_mapping_matrix_list)
321311

322312
@cached_property
323313
@profile_func
@@ -474,46 +464,50 @@ def reconstruction(self) -> np.ndarray:
474464
And the data_vector = ZTx, so the corresponding row is also taken out.
475465
"""
476466

477-
if self.settings.force_edge_pixels_to_zeros:
478-
if self.settings.force_edge_image_pixels_to_zeros:
479-
ids_zeros = np.unique(
480-
np.append(
481-
self.mapper_edge_pixel_list, self.mapper_zero_pixel_list
482-
)
483-
)
484-
else:
485-
ids_zeros = self.mapper_edge_pixel_list
467+
if (
468+
self.has(cls=AbstractMapper)
469+
and self.settings.force_edge_pixels_to_zeros
470+
):
486471

487-
values_to_solve = np.ones(
488-
np.shape(self.curvature_reg_matrix)[0], dtype=bool
472+
ids_zeros = jnp.array(self.mapper_edge_pixel_list, dtype=int)
473+
474+
values_to_solve = jnp.ones(
475+
self.curvature_reg_matrix.shape[0], dtype=bool
489476
)
490-
values_to_solve[ids_zeros] = False
477+
values_to_solve = values_to_solve.at[ids_zeros].set(False)
491478

492479
data_vector_input = self.data_vector[values_to_solve]
493480

494481
curvature_reg_matrix_input = self.curvature_reg_matrix[
495482
values_to_solve, :
496483
][:, values_to_solve]
497484

498-
solutions = np.zeros(np.shape(self.curvature_reg_matrix)[0])
499-
500-
solutions[values_to_solve] = (
501-
inversion_util.reconstruction_positive_only_from(
502-
data_vector=data_vector_input,
503-
curvature_reg_matrix=curvature_reg_matrix_input,
504-
settings=self.settings,
505-
)
485+
# Get the values to assign (must be a JAX array)
486+
reconstruction = inversion_util.reconstruction_positive_only_from(
487+
data_vector=data_vector_input,
488+
curvature_reg_matrix=curvature_reg_matrix_input,
489+
settings=self.settings,
506490
)
491+
492+
# Allocate JAX array
493+
solutions = jnp.zeros(self.curvature_reg_matrix.shape[0])
494+
495+
# Get indices where True
496+
indices = jnp.where(values_to_solve)[0]
497+
498+
# Set reconstruction values at those indices
499+
solutions = solutions.at[indices].set(reconstruction)
500+
507501
return solutions
502+
508503
else:
509-
solutions = inversion_util.reconstruction_positive_only_from(
504+
505+
return inversion_util.reconstruction_positive_only_from(
510506
data_vector=self.data_vector,
511507
curvature_reg_matrix=self.curvature_reg_matrix,
512508
settings=self.settings,
513509
)
514510

515-
return solutions
516-
517511
mapper_param_range_list = self.param_range_list_from(cls=AbstractMapper)
518512

519513
return inversion_util.reconstruction_positive_negative_from(
@@ -522,81 +516,6 @@ def reconstruction(self) -> np.ndarray:
522516
mapper_param_range_list=mapper_param_range_list,
523517
)
524518

525-
# @cached_property
526-
# @profile_func
527-
# def reconstruction(self) -> np.ndarray:
528-
# """
529-
# Solve the linear system [F + reg_coeff*H] S = D -> S = [F + reg_coeff*H]^-1 D given by equation (12)
530-
# of https://arxiv.org/pdf/astro-ph/0302587.pdf (Positive-Negative solution)
531-
#
532-
# ============================================================================================
533-
#
534-
# Solve the Eq.(2) of https://arxiv.org/pdf/astro-ph/0302587.pdf (Non-negative solution)
535-
# Find non-negative solution that minimizes |Z * S - x|^2.
536-
#
537-
# We use fnnls (https://github.com/jvendrow/fnnls) to optimize the quadratic value. Two commonly used
538-
# variables in the code are defined as follows:
539-
# ZTZ := np.dot(Z.T, Z)
540-
# ZTx := np.dot(Z.T, x)
541-
# """
542-
# if self.settings.use_positive_only_solver:
543-
# """
544-
# For the new implementation, we now need to take out the cols and rows of
545-
# the curvature_reg_matrix that corresponds to the parameters we force to be 0.
546-
# Similar for the data vector.
547-
#
548-
# What we actually doing is that we have set the correspoding cols of the Z to be 0.
549-
# As the curvature_reg_matrix = ZTZ, so the cols and rows are all taken out.
550-
# And the data_vector = ZTx, so the corresponding row is also taken out.
551-
# """
552-
#
553-
# if self.settings.force_edge_pixels_to_zeros:
554-
# if self.settings.force_edge_image_pixels_to_zeros:
555-
# ids_zeros = np.unique(
556-
# np.append(
557-
# self.mapper_edge_pixel_list, self.mapper_zero_pixel_list
558-
# )
559-
# )
560-
# else:
561-
# ids_zeros = self.mapper_edge_pixel_list
562-
#
563-
# values_to_solve = np.ones(
564-
# np.shape(self.curvature_reg_matrix)[0], dtype=bool
565-
# )
566-
# values_to_solve[ids_zeros] = False
567-
#
568-
# data_vector_input = self.data_vector[values_to_solve]
569-
#
570-
# curvature_reg_matrix_input = self.curvature_reg_matrix[
571-
# values_to_solve, :
572-
# ][:, values_to_solve]
573-
#
574-
# solutions = inversion_util.reconstruction_positive_only_from(
575-
# data_vector=data_vector_input,
576-
# curvature_reg_matrix=curvature_reg_matrix_input,
577-
# settings=self.settings,
578-
# )
579-
#
580-
# mask = values_to_solve.astype(bool)
581-
#
582-
# return solutions[mask]
583-
# else:
584-
# solutions = inversion_util.reconstruction_positive_only_from(
585-
# data_vector=self.data_vector,
586-
# curvature_reg_matrix=self.curvature_reg_matrix,
587-
# settings=self.settings,
588-
# )
589-
#
590-
# return solutions
591-
#
592-
# mapper_param_range_list = self.param_range_list_from(cls=AbstractMapper)
593-
#
594-
# return inversion_util.reconstruction_positive_negative_from(
595-
# data_vector=self.data_vector,
596-
# curvature_reg_matrix=self.curvature_reg_matrix,
597-
# mapper_param_range_list=mapper_param_range_list,
598-
# )
599-
600519
@cached_property
601520
@profile_func
602521
def reconstruction_reduced(self) -> np.ndarray:

autoarray/inversion/inversion/imaging/mapping.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -109,7 +109,6 @@ def data_vector(self) -> np.ndarray:
109109
110110
The calculation is described in more detail in `inversion_util.data_vector_via_blurred_mapping_matrix_from`.
111111
"""
112-
113112
return inversion_imaging_util.data_vector_via_blurred_mapping_matrix_from(
114113
blurred_mapping_matrix=self.operated_mapping_matrix,
115114
image=self.data.array,

autoarray/inversion/inversion/imaging/w_tilde.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -440,7 +440,7 @@ def _curvature_matrix_func_list_and_mapper(self) -> np.ndarray:
440440
data_weights=mapper.unique_mappings.data_weights,
441441
pix_lengths=mapper.unique_mappings.pix_lengths,
442442
pix_pixels=mapper.params,
443-
curvature_weights=curvature_weights,
443+
curvature_weights=np.array(curvature_weights),
444444
image_frame_1d_lengths=self.convolver.image_frame_1d_lengths,
445445
image_frame_1d_indexes=self.convolver.image_frame_1d_indexes,
446446
image_frame_1d_kernels=self.convolver.image_frame_1d_kernels,

autoarray/inversion/inversion/interferometer/mapping.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -76,7 +76,7 @@ def data_vector(self) -> np.ndarray:
7676
"""
7777

7878
return inversion_interferometer_util.data_vector_via_transformed_mapping_matrix_from(
79-
transformed_mapping_matrix=self.operated_mapping_matrix,
79+
transformed_mapping_matrix=np.array(self.operated_mapping_matrix),
8080
visibilities=np.array(self.data),
8181
noise_map=np.array(self.noise_map),
8282
)
@@ -152,8 +152,10 @@ def mapped_reconstructed_data_dict(
152152

153153
visibilities = (
154154
inversion_interferometer_util.mapped_reconstructed_visibilities_from(
155-
transformed_mapping_matrix=operated_mapping_matrix_list[index],
156-
reconstruction=reconstruction,
155+
transformed_mapping_matrix=np.array(
156+
operated_mapping_matrix_list[index]
157+
),
158+
reconstruction=np.array(reconstruction),
157159
)
158160
)
159161

0 commit comments

Comments
 (0)