88# Internal imports 
99from  popt .misc_tools  import  optim_tools  as  ot 
1010from  pipt .misc_tools  import  analysis_tools  as  at 
11- from  ensemble .ensemble  import  Ensemble  as  PETEnsemble 
11+ from  ensemble .ensemble  import  Ensemble  as  SupEnsemble 
1212from  simulator .simple_models  import  noSimulation 
1313
14- class  EnsembleOptimizationBaseClass (PETEnsemble ):
14+ __all__  =  ['EnsembleOptimizationBaseClass' ]
15+ 
16+ class  EnsembleOptimizationBaseClass (SupEnsemble ):
1517    ''' 
1618    Base class for the popt ensemble 
1719    ''' 
@@ -33,61 +35,64 @@ def __init__(self, options, simulator, objective):
3335        else :
3436            sim  =  simulator 
3537
36-         # Initialize PETEnsemble  
38+         # Initialize the PET Ensemble  
3739        super ().__init__ (options , sim )
3840
3941        # Unpack some options 
4042        self .save_prediction  =  options .get ('save_prediction' , None )
4143        self .num_models       =  options .get ('num_models' , 1 )
4244        self .transform        =  options .get ('transform' , False )
4345        self .num_samples      =  self .ne 
44-         
45-         # Define some variables 
46+ 
47+         # Set objective function (callable) 
48+         self .obj_func  =  objective 
49+         self .state_func_values  =  None 
50+         self .ens_func_values  =  None 
51+ 
52+         # Initialize prior 
53+         self ._initialize_state_info () # Initialize cov, bounds, and state 
54+         self ._scale_state () # Scale self.state to [0, 1] if transform is True 
55+ 
56+     def  _initialize_state_info (self ):
57+         ''' 
58+         Initialize covariance and bounds based on prior information. 
59+         ''' 
60+         self .cov  =  np .array ([])
4661        self .lb  =  []
4762        self .ub  =  []
4863        self .bounds  =  []
49-         self .cov  =  np .array ([])
50- 
51-         # Get bounds and varaince, and initialize state 
64+         
5265        for  key  in  self .prior_info .keys ():
5366            variable  =  self .prior_info [key ]
54- 
67+              
5568            # mean 
5669            self .state [key ] =  np .asarray (variable ['mean' ])
5770
5871            # Covariance 
5972            dim  =  self .state [key ].size 
60-             cov  =  variable ['variance' ]* np .ones (dim )
61- 
73+             var  =  variable ['variance' ]* np .ones (dim )
74+          
6275            if  'limits'  in  variable .keys ():
6376                lb , ub  =  variable ['limits' ]
64-                 self .lb (lb )
65-                 self .ub (ub )
66- 
67-                 # transform cov  to [0, 1] if transform is True 
77+                 self .lb . append (lb )
78+                 self .ub . append (ub )
79+          
80+                 # transform var  to [0, 1] if transform is True 
6881                if  self .transform :
69-                     cov  =  np .clip (cov / (ub  -  lb )** 2 , 0 , 1 , out = cov )
82+                     var  =  var / (ub  -  lb )** 2 
83+                     var  =  np .clip (var , 0 , 1 , out = var )
7084                    self .bounds  +=  dim * [(0 , 1 )]
7185                else :
7286                    self .bounds  +=  dim * [(lb , ub )]
7387            else :
7488                self .bounds  +=  dim * [(None , None )]
7589
7690            # Add to covariance 
77-             self .cov  =  np .append (self .cov , cov )
78-             
91+             self .cov  =  np .append (self .cov , var )
92+             self .dim  =  self .cov .shape [0 ]
93+ 
7994        # Make cov full covariance matrix 
8095        self .cov  =  np .diag (self .cov )
81- 
82-         # Scale the state to [0, 1] if transform is True 
83-         self ._scale_state ()
84- 
85-         # Set objective function (callable) 
86-         self .obj_func  =  objective 
87- 
88-         # Objective function values 
89-         self .state_func_values  =  None 
90-         self .ens_func_values  =  None 
9196
9297    def  get_state (self ):
9398        """ 
@@ -98,6 +103,15 @@ def get_state(self):
98103        """ 
99104        return  ot .aug_optim_state (self .state , list (self .state .keys ()))
100105
106+     def  get_cov (self ):
107+         """ 
108+         Returns 
109+         ------- 
110+         cov : numpy.ndarray 
111+             Covariance matrix, shape (number of controls, number of controls) 
112+         """ 
113+         return  self .cov 
114+     
101115    def  vec_to_state (self , x ):
102116        """ 
103117        Converts a control vector to the internal state representation. 
@@ -114,7 +128,7 @@ def get_bounds(self):
114128
115129        return  self .bounds 
116130
117-     def  function (self , x , * args ):
131+     def  function (self , x , * args ,  ** kwargs ):
118132        """ 
119133        This is the main function called during optimization. 
120134
@@ -130,29 +144,41 @@ def function(self, x, *args):
130144        """ 
131145        self ._aux_input ()
132146
133-         if  len (x .shape ) ==  1 :
134-             self .ne  =  self .num_models 
135-         else :
136-             self .ne  =  x .shape [1 ]
147+         # check for ensmble 
148+         if  len (x .shape ) ==  1 : self .ne  =  self .num_models 
149+         else : self .ne  =  x .shape [1 ]
137150
138-         # convert x to state 
139-         self .state  =  self .vec_to_state (x )   # go from nparray to dict 
151+         # convert x (nparray)  to state (dict)  
152+         self .state  =  self .vec_to_state (x )
140153
141154        # run the simulation 
142155        self ._invert_scale_state ()  # ensure that state is in [lb,ub] 
156+         self ._set_multilevel_state (self .state , x )  # set multilevel state if applicable 
143157        run_success  =  self .calc_prediction (save_prediction = self .save_prediction )  # calculate flow data 
158+         self ._set_multilevel_state (self .state , x ) # For some reason this has to be done again after calc_prediction 
144159        self ._scale_state ()  # scale back to [0, 1] 
160+ 
161+         # Evaluate the objective function 
145162        if  run_success :
146-             func_values  =  self .obj_func (self .pred_data , self .sim .input_dict , self .sim .true_order )
163+             func_values  =  self .obj_func (
164+                 self .pred_data , 
165+                 input_dict = self .sim .input_dict ,
166+                 true_order = self .sim .true_order , 
167+                 ** kwargs 
168+             )
147169        else :
148170            func_values  =  np .inf   # the simulations have crashed 
149171
150-         if  len (x .shape ) ==  1 :
151-             self .state_func_values  =  func_values 
152-         else :
153-             self .ens_func_values  =  func_values 
172+         if  len (x .shape ) ==  1 : self .state_func_values  =  func_values 
173+         else : self .ens_func_values  =  func_values 
154174
155175        return  func_values 
176+     
177+     def  _set_multilevel_state (self , state , x ):
178+         if  'multilevel'  in  self .keys_en .keys () and  len (x .shape ) >  1 :  
179+             en_size  =  ot .get_list_element (self .keys_en ['multilevel' ], 'en_size' )
180+             self .state  =  ot .toggle_ml_state (self .state , en_size )
181+ 
156182
157183    def  _aux_input (self ):
158184        """ 
0 commit comments