Skip to content

Commit ec749aa

Browse files
MathiasMNilsenMathias Methlie NilsenMathias Methlie Nilsen
authored
Some changes to input format and cleaned up popt structure (resubmitted) (#103)
* update to TrustRegion * some design changes to TrustRegion * decoupled GenOpt from Ensemble * decoupled GenOpt from Ensemble * Cleaned up code duplication and renamed some stuff * comments * improved and cleaned up LineSearch * added Newton-CG to LineSearch * improved logging for LineSearch * dummy message for github check * empty commit * udpate BFGS to skip update if negetive curvature * made input more elegant and cleaned up popt structure * changed @cache to @lru_cache --------- Co-authored-by: Mathias Methlie Nilsen <mani@cno-0006.ad.norceresearch.no> Co-authored-by: Mathias Methlie Nilsen <mani@bgo-1714.ad.norceresearch.no>
1 parent d19ff1b commit ec749aa

File tree

19 files changed

+862
-751
lines changed

19 files changed

+862
-751
lines changed

ensemble/ensemble.py

Lines changed: 111 additions & 238 deletions
Large diffs are not rendered by default.

input_output/organize.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
from copy import deepcopy
44
import csv
55
import datetime as dt
6+
import pandas as pd
67

78

89
class Organize_input():
@@ -109,6 +110,13 @@ def _org_report(self):
109110
pred_prim.extend(csv_data)
110111
self.keys_fwd['reportpoint'] = pred_prim
111112

113+
elif isinstance(self.keys_fwd['reportpoint'], dict):
114+
self.keys_fwd['reportpoint'] = pd.date_range(**self.keys_fwd['reportpoint']).to_pydatetime().tolist()
115+
116+
else:
117+
pass
118+
119+
112120
# Check if assimindex is given as a csv file. If so, we read and make a potential 2D list (if sequential).
113121
if 'assimindex' in self.keys_pr:
114122
if isinstance(self.keys_pr['assimindex'], str) and self.keys_pr['assimindex'].endswith('.csv'):

input_output/read_config.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,12 @@ def ndarray_constructor(loader, node):
5151
y = yaml.load(fid, Loader=FullLoader)
5252

5353
# Check for dataassim and fwdsim
54+
if 'ensemble' in y.keys():
55+
keys_en = y['ensemble']
56+
check_mand_keywords_en(keys_en)
57+
else:
58+
keys_en = None
59+
5460
if 'optim' in y.keys():
5561
keys_pr = y['optim']
5662
check_mand_keywords_opt(keys_pr)
@@ -59,16 +65,17 @@ def ndarray_constructor(loader, node):
5965
check_mand_keywords_da(keys_pr)
6066
else:
6167
raise KeyError
68+
6269
if 'fwdsim' in y.keys():
6370
keys_fwd = y['fwdsim']
6471
else:
6572
raise KeyError
6673

6774
# Organize keywords
68-
org = Organize_input(keys_pr, keys_fwd)
75+
org = Organize_input(keys_pr, keys_fwd, keys_en)
6976
org.organize()
7077

71-
return org.get_keys_pr(), org.get_keys_fwd()
78+
return org.get_keys_pr(), org.get_keys_fwd(), org.get_keys_en()
7279

7380

7481
def convert_txt_to_toml(init_file):

pipt/loop/assimilation.py

Lines changed: 11 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -301,32 +301,20 @@ def _ext_max_iter(self):
301301
- ST 7/6-16
302302
"""
303303
if 'iteration' in self.ensemble.keys_da:
304-
# Make sure ITERATION is a list
305-
if not isinstance(self.ensemble.keys_da['iteration'][0], list):
306-
iter_opts = [self.ensemble.keys_da['iteration']]
307-
else:
308-
iter_opts = self.ensemble.keys_da['iteration']
309-
304+
iter_opts = dict(self.ensemble.keys_da['iteration'])
310305
# Check if 'max_iter' has been given; if not, give error (mandatory in ITERATION)
311-
assert 'max_iter' in list(
312-
zip(*iter_opts))[0], 'MAX_ITER has not been given in ITERATION!'
313-
314-
# Extract max. iter
315-
max_iter = [item[1] for item in iter_opts if item[0] == 'max_iter'][0]
306+
try:
307+
max_iter = iter_opts['max_iter']
308+
except KeyError:
309+
raise AssertionError('MAX_ITER has not been given in ITERATION')
316310

317311
elif 'mda' in self.ensemble.keys_da:
318-
# Make sure ITERATION is a list
319-
if not isinstance(self.ensemble.keys_da['mda'][0], list):
320-
iter_opts = [self.ensemble.keys_da['mda']]
321-
else:
322-
iter_opts = self.ensemble.keys_da['mda']
323-
324-
# Check if 'max_iter' has been given; if not, give error (mandatory in ITERATION)
325-
assert 'tot_assim_steps' in list(
326-
zip(*iter_opts))[0], 'TOT_ASSIM_STEPS has not been given in MDA!'
327-
328-
# Extract max. iter
329-
max_iter = [item[1] for item in iter_opts if item[0] == 'tot_assim_steps'][0]
312+
iter_opts = dict(self.ensemble.keys_da['mda'])
313+
# Check if 'tot_assim_steps' has been given; if not, raise error (mandatory in MDA)
314+
try:
315+
max_iter = iter_opts['tot_assim_steps']
316+
except KeyError:
317+
raise AssertionError('TOT_ASSIM_STEPS has not been given in MDA!')
330318

331319
else:
332320
max_iter = 1

pipt/update_schemes/enrml.py

Lines changed: 27 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -250,36 +250,22 @@ def _ext_iter_param(self):
250250
file. These parameters include convergence tolerances and parameters for the damping parameter. Default
251251
values for these parameters have been given here, if they are not provided in ITERATION.
252252
"""
253-
254-
# Predefine all the default values
255-
self.data_misfit_tol = 0.01
256-
self.step_tol = 0.01
257-
self.lam = 100
258-
self.lam_max = 1e10
259-
self.lam_min = 0.01
260-
self.gamma = 5
261-
self.trunc_energy = 0.95
253+
try:
254+
options = dict(self.keys_da['iteration'])
255+
except:
256+
options = dict([self.keys_da['iteration']])
257+
258+
# unpack options
259+
self.data_misfit_tol = options.get('data_misfit_tol', 0.01)
260+
self.trunc_energy = options.get('energy', 0.95)
261+
self.step_tol = options.get('step_tol', 0.01)
262+
self.lam = options.get('lambda', 100)
263+
self.lam_max = options.get('lambda_max', 1e10)
264+
self.lam_min = options.get('lambda_min', 0.01)
265+
self.gamma = options.get('lambda_factor', 5)
262266
self.iteration = 0
263267

264-
# Loop over options in ITERATION and extract the parameters we want
265-
for i, opt in enumerate(list(zip(*self.keys_da['iteration']))[0]):
266-
if opt == 'data_misfit_tol':
267-
self.data_misfit_tol = self.keys_da['iteration'][i][1]
268-
if opt == 'step_tol':
269-
self.step_tol = self.keys_da['iteration'][i][1]
270-
if opt == 'lambda':
271-
self.lam = self.keys_da['iteration'][i][1]
272-
if opt == 'lambda_max':
273-
self.lam_max = self.keys_da['iteration'][i][1]
274-
if opt == 'lambda_min':
275-
self.lam_min = self.keys_da['iteration'][i][1]
276-
if opt == 'lambda_factor':
277-
self.gamma = self.keys_da['iteration'][i][1]
278-
279-
if 'energy' in self.keys_da:
280-
# initial energy (Remember to extract this)
281-
self.trunc_energy = self.keys_da['energy']
282-
if self.trunc_energy > 1: # ensure that it is given as percentage
268+
if self.trunc_energy > 1: # ensure that it is given as percentage
283269
self.trunc_energy /= 100.
284270

285271

@@ -593,33 +579,19 @@ def _ext_iter_param(self):
593579
file. These parameters include convergence tolerances and parameters for the damping parameter. Default
594580
values for these parameters have been given here, if they are not provided in ITERATION.
595581
"""
596-
597-
# Predefine all the default values
598-
self.data_misfit_tol = 0.01
599-
self.step_tol = 0.01
600-
self.gamma = 0.2
601-
self.gamma_max = 0.5
602-
self.gamma_factor = 2.5
603-
self.trunc_energy = 0.95
604-
self.iteration = 0
605-
606-
# Loop over options in ITERATION and extract the parameters we want
607-
for i, opt in enumerate(list(zip(*self.keys_da['iteration']))[0]):
608-
if opt == 'data_misfit_tol':
609-
self.data_misfit_tol = self.keys_da['iteration'][i][1]
610-
if opt == 'step_tol':
611-
self.step_tol = self.keys_da['iteration'][i][1]
612-
if opt == 'gamma':
613-
self.gamma = self.keys_da['iteration'][i][1]
614-
if opt == 'gamma_max':
615-
self.gamma_max = self.keys_da['iteration'][i][1]
616-
if opt == 'gamma_factor':
617-
self.gamma_factor = self.keys_da['iteration'][i][1]
618-
619-
if 'energy' in self.keys_da:
620-
# initial energy (Remember to extract this)
621-
self.trunc_energy = self.keys_da['energy']
622-
if self.trunc_energy > 1: # ensure that it is given as percentage
582+
try:
583+
options = dict(self.keys_da['iteration'])
584+
except:
585+
options = dict([self.keys_da['iteration']])
586+
587+
self.data_misfit_tol = options.get('data_misfit_tol', 0.01)
588+
self.trunc_energy = options.get('energy', 0.95)
589+
self.step_tol = options.get('step_tol', 0.01)
590+
self.gamma = options.get('gamma', 0.2)
591+
self.gamma_max = options.get('gamma_max', 0.5)
592+
self.gamma_factor = options.get('gamma_factor', 2.5)
593+
594+
if self.trunc_energy > 1: # ensure that it is given as percentage
623595
self.trunc_energy /= 100.
624596

625597

pipt/update_schemes/esmda.py

Lines changed: 16 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ def __init__(self, keys_da, keys_en, sim):
3131
3232
Parameters
3333
----------
34-
keys_da['mda'] : list
34+
keys_da['mda'] : dict
3535
- tot_assim_steps: total number of iterations in MDA, e.g., 3
3636
- inflation_param: covariance inflation factors, e.g., [2, 4, 4]
3737
@@ -222,17 +222,16 @@ def _ext_inflation_param(self):
222222
alpha: list
223223
Data covariance inflation factor
224224
"""
225-
# Make sure MDA is a list
226-
if not isinstance(self.keys_da['mda'][0], list):
227-
mda_opts = [self.keys_da['mda']]
228-
else:
229-
mda_opts = self.keys_da['mda']
225+
try:
226+
mda_opts = dict(self.keys_da['mda'])
227+
except:
228+
mda_opts = dict([self.keys_da['mda']])
230229

231230
# Check if INFLATION_PARAM has been provided, and if so, extract the value(s). If not, we set alpha to the
232231
# default value equal to the tot. no. assim. steps
233-
if 'inflation_param' in list(zip(*mda_opts))[0]:
232+
if 'inflation_param' in mda_opts:
234233
# Extract value
235-
alpha_tmp = [item[1] for item in mda_opts if item[0] == 'inflation_param'][0]
234+
alpha_tmp = mda_opts['inflation_param']
236235

237236
# If one value is given, we copy it to all assim. steps. If multiple values are given, we check the
238237
# number of parameters corresponds to tot. no. assim. steps
@@ -279,22 +278,17 @@ def _ext_assim_steps(self):
279278
- ST 7/6-16
280279
- ST 1/3-17: Changed to output list of assim. steps instead of just tot. assim. steps
281280
"""
282-
# Make sure MDA is a list
283-
if not isinstance(self.keys_da['mda'][0], list):
284-
mda_opts = [self.keys_da['mda']]
285-
else:
286-
mda_opts = self.keys_da['mda']
281+
try:
282+
mda_opts = dict(self.keys_da['mda'])
283+
except:
284+
mda_opts = dict([self.keys_da['mda']])
287285

286+
288287
# Check if 'max_iter' has been given; if not, give error (mandatory in ITERATION)
289-
assert 'tot_assim_steps' in list(
290-
zip(*mda_opts))[0], 'TOT_ASSIM_STEPS has not been given in MDA!'
291-
292-
# Extract max. iter
293-
tot_no_assim = int([item[1]
294-
for item in mda_opts if item[0] == 'tot_assim_steps'][0])
295-
296-
# Make a list of assim. steps
297-
assim_steps = list(range(tot_no_assim))
288+
try:
289+
assim_steps = list(range(int(mda_opts['tot_assim_steps'])))
290+
except KeyError:
291+
raise AssertionError('TOT_ASSIM_STEPS has not been given in MDA!')
298292

299293
# If it is a restart run, we remove simulations already done
300294
if self.restart is True:

popt/cost_functions/ecalc_npv.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -44,9 +44,7 @@ def ecalc_npv(pred_data, **kwargs):
4444
report = kwargs.get('true_order', [])
4545

4646
# Economic values
47-
npv_const = {}
48-
for name, value in keys_opt['npv_const']:
49-
npv_const[name] = value
47+
npv_const = dict(keys_opt['npv_const'])
5048

5149
# Collect production data
5250
Qop = []

popt/cost_functions/ecalc_pareto_npv.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -46,9 +46,7 @@ def ecalc_pareto_npv(pred_data, kwargs):
4646
report = kwargs.get('true_order', [])
4747

4848
# Economic values
49-
npv_const = {}
50-
for name, value in keys_opt['npv_const']:
51-
npv_const[name] = value
49+
npv_const = dict(keys_opt['npv_const'])
5250

5351
# Collect production data
5452
Qop = []

popt/cost_functions/npv.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -38,9 +38,7 @@ def npv(pred_data, **kwargs):
3838
report = kwargs.get('true_order', [])
3939

4040
# Economic values
41-
npv_const = {}
42-
for name, value in keys_opt['npv_const']:
43-
npv_const[name] = value
41+
npv_const = dict(keys_opt['npv_const'])
4442

4543
values = []
4644
for i in np.arange(1, len(pred_data)):

popt/cost_functions/ren_npv.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -32,9 +32,7 @@ def ren_npv(pred_data, kwargs):
3232
report = kwargs.get('true_order', [])
3333

3434
# Economic values
35-
npv_const = {}
36-
for name, value in keys_opt['npv_const']:
37-
npv_const[name] = value
35+
npv_const = dict(keys_opt['npv_const'])
3836

3937
# Loop over timesteps
4038
values = []

0 commit comments

Comments
 (0)