Skip to content

Commit ea34c5e

Browse files
author
Lachlan Grose
committed
formatting discrete interpolatory and
removing figure from solver test
1 parent 6c75379 commit ea34c5e

File tree

2 files changed

+21
-14
lines changed

2 files changed

+21
-14
lines changed

LoopStructural/interpolators/discrete_interpolator.py

Lines changed: 21 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -285,7 +285,6 @@ def add_equality_constraints(self, node_idx, values, name="undefined"):
285285
"row": np.arange(self.eq_const_c, self.eq_const_c + idc[outside].shape[0]),
286286
}
287287
self.eq_const_c += idc[outside].shape[0]
288-
289288

290289
def add_non_linear_constraints(self, nonlinear_constraint):
291290
self.non_linear_constraints.append(nonlinear_constraint)
@@ -310,9 +309,9 @@ def add_inequality_constraints_to_matrix(self, A, l, u, idc, name="undefined"):
310309
311310
"""
312311
# map from mesh node index to region node index
313-
gi = np.zeros(self.support.n_nodes,dtype=int)
312+
gi = np.zeros(self.support.n_nodes, dtype=int)
314313
gi[:] = -1
315-
gi[self.region] = np.arange(0, self.nx,dtype=int)
314+
gi[self.region] = np.arange(0, self.nx, dtype=int)
316315
idc = gi[idc]
317316
rows = np.arange(self.ineq_const_c, self.ineq_const_c + idc.shape[0])
318317
rows = np.tile(rows, (A.shape[-1], 1)).T
@@ -438,7 +437,7 @@ def build_matrix(self, square=True, damp=0.0, ie=False):
438437
for c in self.equal_constraints.values():
439438
b.extend((c["B"]).tolist())
440439
mask = aa == 0
441-
a.extend(c["A"]).flatten()[~mask].tolist())
440+
a.extend(c["A"].flatten()[~mask].tolist())
442441
rows.extend(c["row"].flatten()[~mask].tolist())
443442
cols.extend(c["col"].flatten()[~mask].tolist())
444443
C = coo_matrix(
@@ -487,7 +486,7 @@ def build_matrix(self, square=True, damp=0.0, ie=False):
487486
return ATA, ATB, Aie.T.dot(Aie), Aie.T.dot(uie), Aie.T.dot(lie)
488487
return ATA, ATB
489488

490-
def _solve_osqp(self, P, A, q, l, u,mkl=False):
489+
def _solve_osqp(self, P, A, q, l, u, mkl=False):
491490

492491
try:
493492
import osqp
@@ -521,15 +520,24 @@ def _solve_osqp(self, P, A, q, l, u,mkl=False):
521520

522521
# Setup workspace
523522
# osqp likes csc matrices
524-
linsys_solver='qdldl'
523+
linsys_solver = "qdldl"
525524
if mkl:
526-
linsys_solver='mkl pardiso'
527-
525+
linsys_solver = "mkl pardiso"
526+
528527
try:
529-
prob.setup(P.tocsc(), np.array(q), A.tocsc(), np.array(u), np.array(l),linsys_solver=linsys_solver)
528+
prob.setup(
529+
P.tocsc(),
530+
np.array(q),
531+
A.tocsc(),
532+
np.array(u),
533+
np.array(l),
534+
linsys_solver=linsys_solver,
535+
)
530536
except ValueError:
531537
if mkl:
532-
logger.error('MKL solver library path not correct. Please add to LD_LIBRARY_PATH')
538+
logger.error(
539+
"MKL solver library path not correct. Please add to LD_LIBRARY_PATH"
540+
)
533541
raise LoopImportError("Cannot import MKL pardiso")
534542
res = prob.solve()
535543
return res.x
@@ -725,7 +733,9 @@ def _solve(self, solver="cg", **kwargs):
725733
logger.warning("Using external solver")
726734
self.c[self.region] = kwargs["external"](A, B)[: self.nx]
727735
if solver == "osqp":
728-
self.c[self.region] = self._solve_osqp(P, A, q, l, u,mkl=kwargs.get('mkl',False)) # , **kwargs)
736+
self.c[self.region] = self._solve_osqp(
737+
P, A, q, l, u, mkl=kwargs.get("mkl", False)
738+
) # , **kwargs)
729739
# check solution is not nan
730740
# self.support.properties[self.propertyname] = self.c
731741
if np.all(self.c == np.nan):

tests/unit_tests/interpolator/test_solvers.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -48,8 +48,6 @@ def test_FDI():
4848
interpolator.set_value_constraints(data[['X','Y','Z','val','w']].to_numpy())
4949
interpolator._setup_interpolator()
5050
interpolator.solve_system()
51-
plt.imshow(interpolator.evaluate_value(xyz).reshape((50,50)).T)
52-
plt.savefig('normal.png')
5351
assert np.sum(interpolator.evaluate_value(data[['X','Y','Z']].to_numpy())-data[['val']].to_numpy())/len(data) < 0.5
5452
def test_inequality_FDI():
5553
xy = np.array(np.meshgrid(np.linspace(0,1,50),np.linspace(0,1,50))).T.reshape(-1,2)
@@ -90,7 +88,6 @@ def test_inequality_FDI_nodes():
9088
xyz = np.hstack([xy,np.zeros((xy.shape[0],1))])
9189
data = pd.DataFrame(xyz,columns=['X','Y','Z'])
9290
data['val'] = np.sin(data['X'])
93-
print(data['val'].max(),data['val'].min())
9491
data['w'] = 1
9592
data['feature_name'] = 'strati'
9693
data['l'] = -3

0 commit comments

Comments
 (0)