1818from autoarray .util import misc_util
1919from autoarray .inversion .inversion import inversion_util
2020
21+
2122class AbstractInversion :
2223 def __init__ (
2324 self ,
2425 dataset : Union [Imaging , Interferometer , DatasetInterface ],
2526 linear_obj_list : List [LinearObj ],
2627 settings : SettingsInversion = SettingsInversion (),
2728 preloads : Preloads = None ,
28- xp = np
29+ xp = np ,
2930 ):
3031 """
3132 An `Inversion` reconstructs an input dataset using a list of linear objects (e.g. a list of analytic functions
@@ -75,8 +76,6 @@ def __init__(
7576
7677 self ._xp = xp
7778
78-
79-
8079 @property
8180 def data (self ):
8281 return self .dataset .data
@@ -333,10 +332,15 @@ def regularization_matrix(self) -> Optional[np.ndarray]:
333332 """
334333 if self ._xp .__name__ .startswith ("jax" ):
335334 from jax .scipy .linalg import block_diag
335+
336336 return block_diag (
337- * [linear_obj .regularization_matrix for linear_obj in self .linear_obj_list ]
337+ * [
338+ linear_obj .regularization_matrix
339+ for linear_obj in self .linear_obj_list
340+ ]
338341 )
339342 from scipy .linalg import block_diag
343+
340344 return block_diag (
341345 * [linear_obj .regularization_matrix for linear_obj in self .linear_obj_list ]
342346 )
@@ -448,7 +452,7 @@ def reconstruction(self) -> np.ndarray:
448452 data_vector = data_vector ,
449453 curvature_reg_matrix = curvature_reg_matrix ,
450454 settings = self .settings ,
451- xp = self ._xp
455+ xp = self ._xp ,
452456 )
453457 )
454458
@@ -471,13 +475,13 @@ def reconstruction(self) -> np.ndarray:
471475 data_vector = self .data_vector ,
472476 curvature_reg_matrix = self .curvature_reg_matrix ,
473477 settings = self .settings ,
474- xp = self ._xp
478+ xp = self ._xp ,
475479 )
476480
477481 return inversion_util .reconstruction_positive_negative_from (
478482 data_vector = self .data_vector ,
479483 curvature_reg_matrix = self .curvature_reg_matrix ,
480- xp = self ._xp
484+ xp = self ._xp ,
481485 )
482486
483487 @property
@@ -640,7 +644,9 @@ def regularization_term(self) -> float:
640644
641645 return self ._xp .matmul (
642646 self .reconstruction_reduced .T ,
643- self ._xp .matmul (self .regularization_matrix_reduced , self .reconstruction_reduced ),
647+ self ._xp .matmul (
648+ self .regularization_matrix_reduced , self .reconstruction_reduced
649+ ),
644650 )
645651
646652 @property
@@ -654,7 +660,11 @@ def log_det_curvature_reg_matrix_term(self) -> float:
654660 return 0.0
655661
656662 return 2.0 * self ._xp .sum (
657- self ._xp .log (self ._xp .diag (self ._xp .linalg .cholesky (self .curvature_reg_matrix_reduced )))
663+ self ._xp .log (
664+ self ._xp .diag (
665+ self ._xp .linalg .cholesky (self .curvature_reg_matrix_reduced )
666+ )
667+ )
658668 )
659669
660670 @property
@@ -675,7 +685,11 @@ def log_det_regularization_matrix_term(self) -> float:
675685 return 0.0
676686
677687 return 2.0 * self ._xp .sum (
678- self ._xp .log (self ._xp .diag (self ._xp .linalg .cholesky (self .regularization_matrix_reduced )))
688+ self ._xp .log (
689+ self ._xp .diag (
690+ self ._xp .linalg .cholesky (self .regularization_matrix_reduced )
691+ )
692+ )
679693 )
680694
681695 @property
@@ -738,7 +752,9 @@ def regularization_weights_from(self, index: int) -> np.ndarray:
738752
739753 return np .zeros ((pixels ,))
740754
741- return regularization .regularization_weights_from (linear_obj = linear_obj , xp = self ._xp )
755+ return regularization .regularization_weights_from (
756+ linear_obj = linear_obj , xp = self ._xp
757+ )
742758
743759 @property
744760 def regularization_weights_mapper_dict (self ) -> Dict [LinearObj , np .ndarray ]:
0 commit comments