@@ -61,9 +61,42 @@ def curvature_matrix_with_added_to_diag_from(
6161 curvature_matrix
6262 The curvature matrix which is being constructed in order to solve a linear system of equations.
6363 """
64- return curvature_matrix .at [
65- no_regularization_index_list , no_regularization_index_list
66- ].add (value )
64+ try :
65+ return curvature_matrix .at [
66+ no_regularization_index_list , no_regularization_index_list
67+ ].add (value )
68+ except AttributeError :
69+ return curvature_matrix_with_added_to_diag_from_numba (
70+ curvature_matrix = curvature_matrix ,
71+ value = value ,
72+ no_regularization_index_list = no_regularization_index_list ,
73+ )
74+
75+ @numba_util .jit ()
76+ def curvature_matrix_with_added_to_diag_from_numba (
77+ curvature_matrix : np .ndarray ,
78+ value : float ,
79+ no_regularization_index_list : Optional [List ] = None ,
80+ ) -> np .ndarray :
81+ """
82+ It is common for the `curvature_matrix` computed to not be positive-definite, leading for the inversion
83+ via `np.linalg.solve` to fail and raise a `LinAlgError`.
84+
85+ In many circumstances, adding a small numerical value of `1.0e-8` to the diagonal of the `curvature_matrix`
86+ makes it positive definite, such that the inversion is performed without raising an error.
87+
88+ This function adds this numerical value to the diagonal of the curvature matrix.
89+
90+ Parameters
91+ ----------
92+ curvature_matrix
93+ The curvature matrix which is being constructed in order to solve a linear system of equations.
94+ """
95+
96+ for i in no_regularization_index_list :
97+ curvature_matrix [i , i ] += value
98+
99+ return curvature_matrix
67100
68101
69102def curvature_matrix_mirrored_from (
0 commit comments