1+ import numpy as np
2+
3+ from LoopStructural .interpolators ._discrete_interpolator import DiscreteInterpolator
4+ from LoopStructural .interpolators ._finite_difference_interpolator import FiniteDifferenceInterpolator
5+ from ._p1interpolator import P1Interpolator
6+ from typing import Optional , Union , Callable
7+ from scipy import sparse
8+ from LoopStructural .utils import rng
9+
10+ class ConstantNormInterpolator :
11+ """Adds a non linear constraint to an interpolator to constrain
12+ the norm of the gradient to be a set value.
13+
14+ Returns
15+ -------
16+ _type_
17+ _description_
18+ """
19+ def __init__ (self , interpolator : DiscreteInterpolator ):
20+ """Initialise the constant norm inteprolator
21+ with a discrete interpolator.
22+
23+ Parameters
24+ ----------
25+ interpolator : DiscreteInterpolator
26+ The discrete interpolator to add constant norm to.
27+ """
28+ self .interpolator = interpolator
29+ self .support = interpolator .support
30+ self .random_subset = False
31+ self .norm_length = 1.0
32+ def add_constant_norm (self , w :float ):
33+ """Add a constraint to the interpolator to constrain the norm of the gradient
34+ to be a set value
35+
36+ Parameters
37+ ----------
38+ w : float
39+ weighting of the constraint
40+ """
41+ if "constant norm" in self .interpolator .constraints :
42+ _ = self .interpolator .constraints .pop ("constant norm" )
43+
44+ element_indices = np .arange (self .support .elements .shape [0 ])
45+ if self .random_subset :
46+ rng .shuffle (element_indices )
47+ element_indices = element_indices [: int (0.1 * self .support .elements .shape [0 ])]
48+ vertices , gradient , elements , inside = self .support .get_element_gradient_for_location (
49+ self .support .barycentre [element_indices ]
50+ )
51+
52+ t_g = gradient [:, :, :]
53+ # t_n = gradient[self.support.shared_element_relationships[:, 1], :, :]
54+ v_t = np .einsum (
55+ "ijk,ik->ij" ,
56+ t_g ,
57+ self .interpolator .c [self .support .elements [elements ]],
58+ )
59+
60+ v_t = v_t / np .linalg .norm (v_t , axis = 1 )[:, np .newaxis ]
61+ A1 = np .einsum ("ij,ijk->ik" , v_t , t_g )
62+
63+ b = np .zeros (A1 .shape [0 ]) + self .norm_length
64+ idc = np .hstack (
65+ [
66+ self .support .elements [elements ],
67+ ]
68+ )
69+ self .interpolator .add_constraints_to_least_squares (A1 , b , idc , w = w , name = "constant norm" )
70+
71+ def solve_system (
72+ self ,
73+ solver : Optional [Union [Callable [[sparse .csr_matrix , np .ndarray ], np .ndarray ], str ]] = None ,
74+ tol : Optional [float ] = None ,
75+ solver_kwargs : dict = {},
76+ ) -> bool :
77+ """
78+
79+ Parameters
80+ ----------
81+ solver : Optional[Union[Callable[[sparse.csr_matrix, np.ndarray], np.ndarray], str]], optional
82+ _description_, by default None
83+ tol : Optional[float], optional
84+ _description_, by default None
85+ solver_kwargs : dict, optional
86+ _description_, by default {}
87+
88+ Returns
89+ -------
90+ bool
91+ _description_
92+ """
93+ for i in range (20 ):
94+ if i > 0 :
95+ self .add_constant_norm (w = (0.1 * i ) ** 2 + 0.01 )
96+ success = self .interpolator .solve_system (solver = solver , tol = tol , solver_kwargs = solver_kwargs )
97+ return True
98+
99+ class ConstantNormP1Interpolator (P1Interpolator , ConstantNormInterpolator ):
100+ """Constant norm interpolator using P1 base interpolator
101+
102+ Parameters
103+ ----------
104+ P1Interpolator : class
105+ The P1Interpolator class.
106+ ConstantNormInterpolator : class
107+ The ConstantNormInterpolator class.
108+ """
109+ def __init__ (self , support ):
110+ """Initialise the constant norm P1 interpolator.
111+
112+ Parameters
113+ ----------
114+ support : _type_
115+ _description_
116+ """
117+ P1Interpolator .__init__ (self , support )
118+ ConstantNormInterpolator .__init__ (self , self )
119+
120+ def solve_system (
121+ self ,
122+ solver : Optional [Union [Callable [[sparse .csr_matrix , np .ndarray ], np .ndarray ], str ]] = None ,
123+ tol : Optional [float ] = None ,
124+ solver_kwargs : dict = {},
125+ ) -> bool :
126+ """Solve the system of equations for the constant norm P1 interpolator.
127+
128+ Parameters
129+ ----------
130+ solver : Optional[Union[Callable[[sparse.csr_matrix, np.ndarray], np.ndarray], str]], optional
131+ _description_, by default None
132+ tol : Optional[float], optional
133+ _description_, by default None
134+ solver_kwargs : dict, optional
135+ _description_, by default {}
136+
137+ Returns
138+ -------
139+ bool
140+ _description_
141+ """
142+ success = True
143+ for i in range (20 ):
144+
145+ if i > 0 :
146+ self .add_constant_norm (w = (0.1 * i ) ** 2 + 0.01 )
147+ success = P1Interpolator .solve_system (self , solver , tol , solver_kwargs )
148+ return success
149+ class ConstantNormFDIInterpolator (FiniteDifferenceInterpolator , ConstantNormInterpolator ):
150+ """Constant norm interpolator using finite difference base interpolator
151+
152+ Parameters
153+ ----------
154+ FiniteDifferenceInterpolator : class
155+ The FiniteDifferenceInterpolator class.
156+ ConstantNormInterpolator : class
157+ The ConstantNormInterpolator class.
158+ """
159+ def __init__ (self , support ):
160+ """Initialise the constant norm finite difference interpolator.
161+
162+ Parameters
163+ ----------
164+ support : _type_
165+ _description_
166+ """
167+ FiniteDifferenceInterpolator .__init__ (self , support )
168+ ConstantNormInterpolator .__init__ (self , self )
169+ def solve_system (
170+ self ,
171+ solver : Optional [Union [Callable [[sparse .csr_matrix , np .ndarray ], np .ndarray ], str ]] = None ,
172+ tol : Optional [float ] = None ,
173+ solver_kwargs : dict = {},
174+ ) -> bool :
175+ """Solve the system of equations for the constant norm finite difference interpolator.
176+
177+ Parameters
178+ ----------
179+ solver : Optional[Union[Callable[[sparse.csr_matrix, np.ndarray], np.ndarray], str]], optional
180+ _description_, by default None
181+ tol : Optional[float], optional
182+ _description_, by default None
183+ solver_kwargs : dict, optional
184+ _description_, by default {}
185+
186+ Returns
187+ -------
188+ bool
189+ _description_
190+ """
191+ success = True
192+ for i in range (20 ):
193+ if i > 0 :
194+ self .add_constant_norm (w = (0.1 * i ) ** 2 + 0.01 )
195+ success = FiniteDifferenceInterpolator .solve_system (self , solver , tol , solver_kwargs )
196+ return success
0 commit comments