@@ -93,10 +93,7 @@ def test_inequality_FDI_nodes():
9393 data ['l' ] = - 3
9494 data ['u' ] = 10
9595 randind = np .arange (0 ,len (data ))
96- # np.random.shuffle(randind)
97- # # data.loc[randind[:int(50*50*.5)],'val'] = np.nan
98- # np.random.shuffle(randind)
99- # data.loc[randind[:int(50*50*.1)],'val'] = 0
96+
10097 origin = np .array ([- 0.1 ,- 0.1 ,- 0.1 ])
10198 maximum = np .array ([1.1 ,1.1 ,1.1 ])
10299 nsteps = np .array ([20 ,20 ,20 ])
@@ -106,18 +103,30 @@ def test_inequality_FDI_nodes():
106103 interpolator .set_value_constraints (data [['X' ,'Y' ,'Z' ,'val' ,'w' ]].to_numpy ())
107104 interpolator .set_inequality_constraints (data [['X' ,'Y' ,'Z' ,'l' ,'u' ]].to_numpy ())
108105 interpolator ._setup_interpolator ()
109- # col = np.arange(0,interpolator.nx,dtype=int)
110- # col = np.tile(col, (interpolator.nx, 1)).T
111- # interpolator.add_inequality_constraints_to_matrix(np.eye(interpolator.nx),
112- # np.zeros(interpolator.nx)-4,
113- # np.zeros(interpolator.nx)+np.inf,
114- # col
115- # )
116106 interpolator .solve_system (solver = 'osqp' )
117107
118- # print(np.sum(interpolator.evaluate_value(data[['X','Y','Z']].to_numpy())-data[['val']].to_numpy())/len(data))
119- # assert np.sum(interpolator.evaluate_value(data[['X','Y','Z']].to_numpy())-data[['val']].to_numpy())/len(data) < 0.5
108+ def test_equality_FDI_nodes ():
109+ xy = np .array (np .meshgrid (np .linspace (0 ,1 ,50 ),np .linspace (0 ,1 ,50 ))).T .reshape (- 1 ,2 )
110+ xyz = np .hstack ([xy ,np .zeros ((xy .shape [0 ],1 ))])
111+ data = pd .DataFrame (xyz ,columns = ['X' ,'Y' ,'Z' ])
112+ data ['val' ] = np .sin (data ['X' ])
113+ data ['w' ] = 1
114+ data ['feature_name' ] = 'strati'
115+ origin = np .array ([- 0.1 ,- 0.1 ,- 0.1 ])
116+ maximum = np .array ([1.1 ,1.1 ,1.1 ])
117+ nsteps = np .array ([20 ,20 ,20 ])
118+ step_vector = (maximum - origin )/ nsteps
119+ grid = StructuredGrid (origin = origin ,nsteps = nsteps ,step_vector = step_vector )
120+ interpolator = FDI (grid )
121+ interpolator .set_value_constraints (data [['X' ,'Y' ,'Z' ,'val' ,'w' ]].to_numpy ())
122+
123+ node_idx = np .arange (0 ,interpolator .nx )[interpolator .support .nodes [:,2 ]> .9 ]
124+ interpolator .add_equality_constraints (node_idx ,np .ones (node_idx .shape [0 ]),name = 'top' )
125+ interpolator ._setup_interpolator ()
126+ interpolator .solve_system (solver = 'cg' )
127+
120128if __name__ == '__main__' :
121- test_inequality_FDI ()
122- test_inequality_FDI_nodes ()
123- test_FDI ()
129+ # test_inequality_FDI()
130+ # test_inequality_FDI_nodes()
131+ # test_FDI()
132+ test_equality_FDI_nodes ()
0 commit comments