11import copy
2+ import jax
23import jax .numpy as jnp
34import numpy as np
45from scipy .linalg import block_diag
@@ -73,17 +74,6 @@ def __init__(
7374 A dictionary which contains timing of certain functions calls which is used for profiling.
7475 """
7576
76- # try:
77- # import numba
78- # except ModuleNotFoundError:
79- # raise exc.InversionException(
80- # "Inversion functionality (linear light profiles, pixelized reconstructions) is "
81- # "disabled if numba is not installed.\n\n"
82- # "This is because the run-times without numba are too slow.\n\n"
83- # "Please install numba, which is described at the following web page:\n\n"
84- # "https://pyautolens.readthedocs.io/en/latest/installation/overview.html"
85- # )
86-
8777 self .dataset = dataset
8878
8979 self .linear_obj_list = linear_obj_list
@@ -317,7 +307,7 @@ def operated_mapping_matrix(self) -> np.ndarray:
317307 If there are multiple linear objects, the blurred mapping matrices are stacked such that their simultaneous
318308 linear equations are solved simultaneously.
319309 """
320- return np .hstack (self .operated_mapping_matrix_list )
310+ return jnp .hstack (self .operated_mapping_matrix_list )
321311
322312 @cached_property
323313 @profile_func
@@ -474,46 +464,50 @@ def reconstruction(self) -> np.ndarray:
474464 And the data_vector = ZTx, so the corresponding row is also taken out.
475465 """
476466
477- if self .settings .force_edge_pixels_to_zeros :
478- if self .settings .force_edge_image_pixels_to_zeros :
479- ids_zeros = np .unique (
480- np .append (
481- self .mapper_edge_pixel_list , self .mapper_zero_pixel_list
482- )
483- )
484- else :
485- ids_zeros = self .mapper_edge_pixel_list
467+ if (
468+ self .has (cls = AbstractMapper )
469+ and self .settings .force_edge_pixels_to_zeros
470+ ):
486471
487- values_to_solve = np .ones (
488- np .shape (self .curvature_reg_matrix )[0 ], dtype = bool
472+ ids_zeros = jnp .array (self .mapper_edge_pixel_list , dtype = int )
473+
474+ values_to_solve = jnp .ones (
475+ self .curvature_reg_matrix .shape [0 ], dtype = bool
489476 )
490- values_to_solve [ ids_zeros ] = False
477+ values_to_solve = values_to_solve . at [ ids_zeros ]. set ( False )
491478
492479 data_vector_input = self .data_vector [values_to_solve ]
493480
494481 curvature_reg_matrix_input = self .curvature_reg_matrix [
495482 values_to_solve , :
496483 ][:, values_to_solve ]
497484
498- solutions = np .zeros (np .shape (self .curvature_reg_matrix )[0 ])
499-
500- solutions [values_to_solve ] = (
501- inversion_util .reconstruction_positive_only_from (
502- data_vector = data_vector_input ,
503- curvature_reg_matrix = curvature_reg_matrix_input ,
504- settings = self .settings ,
505- )
485+ # Get the values to assign (must be a JAX array)
486+ reconstruction = inversion_util .reconstruction_positive_only_from (
487+ data_vector = data_vector_input ,
488+ curvature_reg_matrix = curvature_reg_matrix_input ,
489+ settings = self .settings ,
506490 )
491+
492+ # Allocate JAX array
493+ solutions = jnp .zeros (self .curvature_reg_matrix .shape [0 ])
494+
495+ # Get indices where True
496+ indices = jnp .where (values_to_solve )[0 ]
497+
498+ # Set reconstruction values at those indices
499+ solutions = solutions .at [indices ].set (reconstruction )
500+
507501 return solutions
502+
508503 else :
509- solutions = inversion_util .reconstruction_positive_only_from (
504+
505+ return inversion_util .reconstruction_positive_only_from (
510506 data_vector = self .data_vector ,
511507 curvature_reg_matrix = self .curvature_reg_matrix ,
512508 settings = self .settings ,
513509 )
514510
515- return solutions
516-
517511 mapper_param_range_list = self .param_range_list_from (cls = AbstractMapper )
518512
519513 return inversion_util .reconstruction_positive_negative_from (
@@ -522,81 +516,6 @@ def reconstruction(self) -> np.ndarray:
522516 mapper_param_range_list = mapper_param_range_list ,
523517 )
524518
525- # @cached_property
526- # @profile_func
527- # def reconstruction(self) -> np.ndarray:
528- # """
529- # Solve the linear system [F + reg_coeff*H] S = D -> S = [F + reg_coeff*H]^-1 D given by equation (12)
530- # of https://arxiv.org/pdf/astro-ph/0302587.pdf (Positive-Negative solution)
531- #
532- # ============================================================================================
533- #
534- # Solve the Eq.(2) of https://arxiv.org/pdf/astro-ph/0302587.pdf (Non-negative solution)
535- # Find non-negative solution that minimizes |Z * S - x|^2.
536- #
537- # We use fnnls (https://github.com/jvendrow/fnnls) to optimize the quadratic value. Two commonly used
538- # variables in the code are defined as follows:
539- # ZTZ := np.dot(Z.T, Z)
540- # ZTx := np.dot(Z.T, x)
541- # """
542- # if self.settings.use_positive_only_solver:
543- # """
544- # For the new implementation, we now need to take out the cols and rows of
545- # the curvature_reg_matrix that corresponds to the parameters we force to be 0.
546- # Similar for the data vector.
547- #
548- # What we actually doing is that we have set the correspoding cols of the Z to be 0.
549- # As the curvature_reg_matrix = ZTZ, so the cols and rows are all taken out.
550- # And the data_vector = ZTx, so the corresponding row is also taken out.
551- # """
552- #
553- # if self.settings.force_edge_pixels_to_zeros:
554- # if self.settings.force_edge_image_pixels_to_zeros:
555- # ids_zeros = np.unique(
556- # np.append(
557- # self.mapper_edge_pixel_list, self.mapper_zero_pixel_list
558- # )
559- # )
560- # else:
561- # ids_zeros = self.mapper_edge_pixel_list
562- #
563- # values_to_solve = np.ones(
564- # np.shape(self.curvature_reg_matrix)[0], dtype=bool
565- # )
566- # values_to_solve[ids_zeros] = False
567- #
568- # data_vector_input = self.data_vector[values_to_solve]
569- #
570- # curvature_reg_matrix_input = self.curvature_reg_matrix[
571- # values_to_solve, :
572- # ][:, values_to_solve]
573- #
574- # solutions = inversion_util.reconstruction_positive_only_from(
575- # data_vector=data_vector_input,
576- # curvature_reg_matrix=curvature_reg_matrix_input,
577- # settings=self.settings,
578- # )
579- #
580- # mask = values_to_solve.astype(bool)
581- #
582- # return solutions[mask]
583- # else:
584- # solutions = inversion_util.reconstruction_positive_only_from(
585- # data_vector=self.data_vector,
586- # curvature_reg_matrix=self.curvature_reg_matrix,
587- # settings=self.settings,
588- # )
589- #
590- # return solutions
591- #
592- # mapper_param_range_list = self.param_range_list_from(cls=AbstractMapper)
593- #
594- # return inversion_util.reconstruction_positive_negative_from(
595- # data_vector=self.data_vector,
596- # curvature_reg_matrix=self.curvature_reg_matrix,
597- # mapper_param_range_list=mapper_param_range_list,
598- # )
599-
600519 @cached_property
601520 @profile_func
602521 def reconstruction_reduced (self ) -> np .ndarray :
0 commit comments