Skip to content

Commit ef109d1

Browse files
author
Lachlan Grose
committed
test: ✅ new test for equality constraints
Added a new test for equality constraints using FDI. Need to add more options for other solver.
1 parent ea34c5e commit ef109d1

File tree

2 files changed

+35
-18
lines changed

2 files changed

+35
-18
lines changed

LoopStructural/interpolators/discrete_interpolator.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -277,6 +277,7 @@ def add_equality_constraints(self, node_idx, values, name="undefined"):
277277
gi[self.region] = np.arange(0, self.nx)
278278
idc = gi[node_idx]
279279
outside = ~(idc == -1)
280+
280281
self.equal_constraints[name] = {
281282
"A": np.ones(idc[outside].shape[0]),
282283
"B": values[outside],
@@ -434,15 +435,22 @@ def build_matrix(self, square=True, damp=0.0, ie=False):
434435
# c are the node values and y are the
435436
# lagrange multipliers#
436437
nc = 0
438+
a = []
439+
rows = []
440+
cols = []
441+
b = []
437442
for c in self.equal_constraints.values():
438443
b.extend((c["B"]).tolist())
444+
aa = c["A"].flatten()
439445
mask = aa == 0
440-
a.extend(c["A"].flatten()[~mask].tolist())
446+
a.extend(aa[~mask].tolist())
441447
rows.extend(c["row"].flatten()[~mask].tolist())
442448
cols.extend(c["col"].flatten()[~mask].tolist())
449+
443450
C = coo_matrix(
451+
444452
(np.array(a), (np.array(rows), cols)),
445-
shape=(self.eq_const_c_, self.nx),
453+
shape=(self.eq_const_c, self.nx),
446454
dtype=float,
447455
).tocsr()
448456

tests/unit_tests/interpolator/test_solvers.py

Lines changed: 25 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -93,10 +93,7 @@ def test_inequality_FDI_nodes():
9393
data['l'] = -3
9494
data['u'] = 10
9595
randind = np.arange(0,len(data))
96-
# np.random.shuffle(randind)
97-
# # data.loc[randind[:int(50*50*.5)],'val'] = np.nan
98-
# np.random.shuffle(randind)
99-
# data.loc[randind[:int(50*50*.1)],'val'] = 0
96+
10097
origin = np.array([-0.1,-0.1,-0.1])
10198
maximum = np.array([1.1,1.1,1.1])
10299
nsteps = np.array([20,20,20])
@@ -106,18 +103,30 @@ def test_inequality_FDI_nodes():
106103
interpolator.set_value_constraints(data[['X','Y','Z','val','w']].to_numpy())
107104
interpolator.set_inequality_constraints(data[['X','Y','Z','l','u']].to_numpy())
108105
interpolator._setup_interpolator()
109-
# col = np.arange(0,interpolator.nx,dtype=int)
110-
# col = np.tile(col, (interpolator.nx, 1)).T
111-
# interpolator.add_inequality_constraints_to_matrix(np.eye(interpolator.nx),
112-
# np.zeros(interpolator.nx)-4,
113-
# np.zeros(interpolator.nx)+np.inf,
114-
# col
115-
# )
116106
interpolator.solve_system(solver='osqp')
117107

118-
# print(np.sum(interpolator.evaluate_value(data[['X','Y','Z']].to_numpy())-data[['val']].to_numpy())/len(data))
119-
# assert np.sum(interpolator.evaluate_value(data[['X','Y','Z']].to_numpy())-data[['val']].to_numpy())/len(data) < 0.5
108+
def test_equality_FDI_nodes():
109+
xy = np.array(np.meshgrid(np.linspace(0,1,50),np.linspace(0,1,50))).T.reshape(-1,2)
110+
xyz = np.hstack([xy,np.zeros((xy.shape[0],1))])
111+
data = pd.DataFrame(xyz,columns=['X','Y','Z'])
112+
data['val'] = np.sin(data['X'])
113+
data['w'] = 1
114+
data['feature_name'] = 'strati'
115+
origin = np.array([-0.1,-0.1,-0.1])
116+
maximum = np.array([1.1,1.1,1.1])
117+
nsteps = np.array([20,20,20])
118+
step_vector = (maximum-origin)/nsteps
119+
grid = StructuredGrid(origin=origin,nsteps=nsteps,step_vector=step_vector)
120+
interpolator = FDI(grid)
121+
interpolator.set_value_constraints(data[['X','Y','Z','val','w']].to_numpy())
122+
123+
node_idx = np.arange(0,interpolator.nx)[interpolator.support.nodes[:,2]>.9]
124+
interpolator.add_equality_constraints(node_idx,np.ones(node_idx.shape[0]),name='top')
125+
interpolator._setup_interpolator()
126+
interpolator.solve_system(solver='cg')
127+
120128
if __name__ == '__main__':
121-
test_inequality_FDI()
122-
test_inequality_FDI_nodes()
123-
test_FDI()
129+
# test_inequality_FDI()
130+
# test_inequality_FDI_nodes()
131+
# test_FDI()
132+
test_equality_FDI_nodes()

0 commit comments

Comments
 (0)