@@ -16,7 +16,7 @@ class ConstantNormInterpolator:
1616 _type_
1717 _description_
1818 """
19- def __init__ (self , interpolator : DiscreteInterpolator ):
19+ def __init__ (self , interpolator : DiscreteInterpolator , basetype ):
2020 """Initialise the constant norm inteprolator
2121 with a discrete interpolator.
2222
@@ -25,10 +25,12 @@ def __init__(self, interpolator: DiscreteInterpolator):
2525 interpolator : DiscreteInterpolator
2626 The discrete interpolator to add constant norm to.
2727 """
28+ self .basetype = basetype
2829 self .interpolator = interpolator
2930 self .support = interpolator .support
3031 self .random_subset = False
3132 self .norm_length = 1.0
33+ self .n_iterations = 20
3234 def add_constant_norm (self , w :float ):
3335 """Add a constraint to the interpolator to constrain the norm of the gradient
3436 to be a set value
@@ -74,27 +76,33 @@ def solve_system(
7476 tol : Optional [float ] = None ,
7577 solver_kwargs : dict = {},
7678 ) -> bool :
77- """
79+ """Solve the system of equations iteratively for the constant norm interpolator.
7880
7981 Parameters
8082 ----------
8183 solver : Optional[Union[Callable[[sparse.csr_matrix, np.ndarray], np.ndarray], str]], optional
82- _description_ , by default None
84+ Solver function or name , by default None
8385 tol : Optional[float], optional
84- _description_ , by default None
86+ Tolerance for the solver , by default None
8587 solver_kwargs : dict, optional
86- _description_ , by default {}
88+ Additional arguments for the solver , by default {}
8789
8890 Returns
8991 -------
9092 bool
91- _description_
93+ Success status of the solver
9294 """
9395 success = True
94- for i in range (20 ):
96+ for i in range (self . n_iterations ):
9597 if i > 0 :
9698 self .add_constant_norm (w = (0.1 * i ) ** 2 + 0.01 )
97- success = self .interpolator .solve_system (solver = solver , tol = tol , solver_kwargs = solver_kwargs )
99+ # Ensure the interpolator is cast to P1Interpolator before calling solve_system
100+ if isinstance (self .interpolator , self .basetype ):
101+ success = self .basetype .solve_system (self .interpolator , solver = solver , tol = tol , solver_kwargs = solver_kwargs )
102+ else :
103+ raise TypeError ("self.interpolator is not an instance of P1Interpolator" )
104+ if not success :
105+ break
98106 return success
99107
100108class ConstantNormP1Interpolator (P1Interpolator , ConstantNormInterpolator ):
@@ -116,7 +124,7 @@ def __init__(self, support):
116124 _description_
117125 """
118126 P1Interpolator .__init__ (self , support )
119- ConstantNormInterpolator .__init__ (self , self )
127+ ConstantNormInterpolator .__init__ (self , self , P1Interpolator )
120128
121129 def solve_system (
122130 self ,
@@ -129,24 +137,19 @@ def solve_system(
129137 Parameters
130138 ----------
131139 solver : Optional[Union[Callable[[sparse.csr_matrix, np.ndarray], np.ndarray], str]], optional
132- _description_ , by default None
140+ Solver function or name , by default None
133141 tol : Optional[float], optional
134- _description_ , by default None
142+ Tolerance for the solver , by default None
135143 solver_kwargs : dict, optional
136- _description_ , by default {}
144+ Additional arguments for the solver , by default {}
137145
138146 Returns
139147 -------
140148 bool
141- _description_
149+ Success status of the solver
142150 """
143- success = True
144- for i in range (20 ):
151+ return ConstantNormInterpolator .solve_system (self , solver = solver , tol = tol , solver_kwargs = solver_kwargs )
145152
146- if i > 0 :
147- self .add_constant_norm (w = (0.1 * i ) ** 2 + 0.01 )
148- success = P1Interpolator .solve_system (self , solver , tol , solver_kwargs )
149- return success
150153class ConstantNormFDIInterpolator (FiniteDifferenceInterpolator , ConstantNormInterpolator ):
151154 """Constant norm interpolator using finite difference base interpolator
152155
@@ -166,7 +169,7 @@ def __init__(self, support):
166169 _description_
167170 """
168171 FiniteDifferenceInterpolator .__init__ (self , support )
169- ConstantNormInterpolator .__init__ (self , self )
172+ ConstantNormInterpolator .__init__ (self , self , FiniteDifferenceInterpolator )
170173 def solve_system (
171174 self ,
172175 solver : Optional [Union [Callable [[sparse .csr_matrix , np .ndarray ], np .ndarray ], str ]] = None ,
@@ -178,20 +181,15 @@ def solve_system(
178181 Parameters
179182 ----------
180183 solver : Optional[Union[Callable[[sparse.csr_matrix, np.ndarray], np.ndarray], str]], optional
181- _description_ , by default None
184+ Solver function or name , by default None
182185 tol : Optional[float], optional
183- _description_ , by default None
186+ Tolerance for the solver , by default None
184187 solver_kwargs : dict, optional
185- _description_ , by default {}
188+ Additional arguments for the solver , by default {}
186189
187190 Returns
188191 -------
189192 bool
190- _description_
193+ Success status of the solver
191194 """
192- success = True
193- for i in range (20 ):
194- if i > 0 :
195- self .add_constant_norm (w = (0.1 * i ) ** 2 + 0.01 )
196- success = FiniteDifferenceInterpolator .solve_system (self , solver , tol , solver_kwargs )
197- return success
195+ return ConstantNormInterpolator .solve_system (self , solver = solver , tol = tol , solver_kwargs = solver_kwargs )
0 commit comments