25
25
parse_nested_dicts ,
26
26
)
27
27
from floatcsep .infrastructure .engine import Task , TaskGraph
28
+ from floatcsep .infrastructure .logger import log_models_tree , log_results_tree
28
29
29
30
log = logging .getLogger ("floatLogger" )
30
31
@@ -52,8 +53,8 @@ class Experiment:
52
53
- growth (:class:`str`): `incremental` or `cumulative`
53
54
- offset (:class:`float`): recurrence of forecast creation.
54
55
55
- For further details, see :func:`~floatcsep.utils.timewindows_ti `
56
- and :func:`~floatcsep.utils.timewindows_td `
56
+ For further details, see :func:`~floatcsep.utils.time_windows_ti `
57
+ and :func:`~floatcsep.utils.time_windows_td `
57
58
58
59
region_config (dict): Contains all the spatial and magnitude
59
60
specifications. It must contain the following keys:
@@ -75,6 +76,7 @@ class Experiment:
75
76
76
77
model_config (str): Path to the models' configuration file
77
78
test_config (str): Path to the evaluations' configuration file
79
+ run_mode (str): 'sequential' or 'parallel'
78
80
default_test_kwargs (dict): Default values for the testing
79
81
(seed, number of simulations, etc.)
80
82
postprocess (dict): Contains the instruction for postprocessing
@@ -99,6 +101,7 @@ def __init__(
99
101
postprocess : str = None ,
100
102
default_test_kwargs : dict = None ,
101
103
rundir : str = "results" ,
104
+ run_mode : str = "sequential" ,
102
105
report_hook : dict = None ,
103
106
** kwargs ,
104
107
) -> None :
@@ -118,14 +121,15 @@ def __init__(
118
121
os .makedirs (os .path .join (workdir , rundir ), exist_ok = True )
119
122
120
123
self .name = name if name else "floatingExp"
121
- self .registry = ExperimentRegistry (workdir , rundir )
124
+ self .registry = ExperimentRegistry . factory (workdir = workdir , run_dir = rundir )
122
125
self .results_repo = ResultsRepository (self .registry )
123
126
self .catalog_repo = CatalogRepository (self .registry )
124
127
125
128
self .config_file = kwargs .get ("config_file" , None )
126
129
self .original_config = kwargs .get ("original_config" , None )
127
130
self .original_run_dir = kwargs .get ("original_rundir" , None )
128
131
self .run_dir = rundir
132
+ self .run_mode = run_mode
129
133
self .seed = kwargs .get ("seed" , None )
130
134
self .time_config = read_time_cfg (time_config , ** kwargs )
131
135
self .region_config = read_region_cfg (region_config , ** kwargs )
@@ -143,7 +147,7 @@ def __init__(
143
147
log .info (f"Setting up experiment { self .name } :" )
144
148
log .info (f"\t Start: { self .start_date } " )
145
149
log .info (f"\t End: { self .end_date } " )
146
- log .info (f"\t Time windows: { len (self .timewindows )} " )
150
+ log .info (f"\t Time windows: { len (self .time_windows )} " )
147
151
log .info (f"\t Region: { self .region .name if self .region else None } " )
148
152
log .info (
149
153
f"\t Magnitude range: [{ numpy .min (self .magnitudes )} ,"
@@ -175,7 +179,7 @@ def __getattr__(self, item: str) -> object:
175
179
Override built-in method to return the experiment attributes by also using the command
176
180
``experiment.{attr}``. Adds also to the experiment scope the keys of
177
181
:attr:`region_config` or :attr:`time_config`. These are: ``start_date``, ``end_date``,
178
- ``timewindows ``, ``horizon``, ``offset``, ``region``, ``magnitudes``, ``mag_min``,
182
+ ``time_windows ``, ``horizon``, ``offset``, ``region``, ``magnitudes``, ``mag_min``,
179
183
`mag_max``, ``mag_bin``, ``depth_min`` depth_max .
180
184
"""
181
185
@@ -295,8 +299,8 @@ def stage_models(self) -> None:
295
299
"""
296
300
log .info ("Staging models" )
297
301
for i in self .models :
298
- i .stage (self .timewindows )
299
- self .registry .add_forecast_registry (i )
302
+ i .stage (self .time_windows , run_mode = self . run_mode , run_dir = self . run_dir )
303
+ self .registry .add_model_registry (i )
300
304
301
305
def set_tests (self , test_config : Union [str , Dict , List ]) -> list :
302
306
"""
@@ -376,17 +380,17 @@ def set_tasks(self) -> None:
376
380
"""
377
381
378
382
# Set the file path structure
379
- self .registry .build_tree (self .timewindows , self .models , self .tests )
383
+ self .registry .build_tree (self .time_windows , self .models , self .tests , self . run_mode )
380
384
381
385
log .debug ("Pre-run forecast summary" )
382
- self .registry . log_forecast_trees ( self .timewindows )
386
+ log_models_tree ( log , self .registry , self .time_windows )
383
387
log .debug ("Pre-run result summary" )
384
- self .registry . log_results_tree ( )
388
+ log_results_tree ( log , self .registry )
385
389
386
390
log .info ("Setting up experiment's tasks" )
387
391
388
392
# Get the time windows strings
389
- tw_strings = timewindow2str (self .timewindows )
393
+ tw_strings = timewindow2str (self .time_windows )
390
394
391
395
# Prepare the testing catalogs
392
396
task_graph = TaskGraph ()
@@ -481,7 +485,7 @@ def set_tasks(self) -> None:
481
485
)
482
486
# Set up the Sequential_Comparative Scores
483
487
elif test_k .type == "sequential_comparative" :
484
- tw_strs = timewindow2str (self .timewindows )
488
+ tw_strs = timewindow2str (self .time_windows )
485
489
for model_j in self .models :
486
490
task_k = Task (
487
491
instance = test_k ,
@@ -504,7 +508,7 @@ def set_tasks(self) -> None:
504
508
)
505
509
# Set up the Batch comparative Scores
506
510
elif test_k .type == "batch" :
507
- time_str = timewindow2str (self .timewindows [- 1 ])
511
+ time_str = timewindow2str (self .time_windows [- 1 ])
508
512
for model_j in self .models :
509
513
task_k = Task (
510
514
instance = test_k ,
@@ -540,9 +544,9 @@ def run(self) -> None:
540
544
self .task_graph .run ()
541
545
log .info ("Calculation completed" )
542
546
log .debug ("Post-run forecast registry" )
543
- self .registry . log_forecast_trees ( self .timewindows )
547
+ log_models_tree ( log , self .registry , self .time_windows )
544
548
log .debug ("Post-run result summary" )
545
- self .registry . log_results_tree ( )
549
+ log_results_tree ( log , self .registry )
546
550
547
551
def read_results (self , test : Evaluation , window : str ) -> List :
548
552
"""
@@ -559,7 +563,7 @@ def make_repr(self) -> None:
559
563
560
564
"""
561
565
log .info ("Creating reproducibility config file" )
562
- repr_config = self .registry .get ("repr_config" )
566
+ repr_config = self .registry .get_attr ("repr_config" )
563
567
564
568
# Dropping region to results folder if it is a file
565
569
region_path = self .region_config .get ("path" , False )
@@ -604,7 +608,7 @@ def as_dict(self, extra: Sequence = (), extended=False) -> dict:
604
608
"time_config" : {
605
609
i : j
606
610
for i , j in self .time_config .items ()
607
- if (i not in ("timewindows " ,) or extended )
611
+ if (i not in ("time_windows " ,) or extended )
608
612
},
609
613
"region_config" : {
610
614
i : j
@@ -731,7 +735,7 @@ def test_stat(test_orig, test_repr):
731
735
732
736
def get_results (self ):
733
737
734
- win_orig = timewindow2str (self .original .timewindows )
738
+ win_orig = timewindow2str (self .original .time_windows )
735
739
736
740
tests_orig = self .original .tests
737
741
@@ -787,7 +791,7 @@ def get_hash(filename):
787
791
788
792
def get_filecomp (self ):
789
793
790
- win_orig = timewindow2str (self .original .timewindows )
794
+ win_orig = timewindow2str (self .original .time_windows )
791
795
792
796
tests_orig = self .original .tests
793
797
@@ -801,8 +805,8 @@ def get_filecomp(self):
801
805
for tw in win_orig :
802
806
results [test .name ][tw ] = dict .fromkeys (models_orig )
803
807
for model in models_orig :
804
- orig_path = self .original .registry .get_result (tw , test , model )
805
- repr_path = self .reproduced .registry .get_result (tw , test , model )
808
+ orig_path = self .original .registry .get_result_key (tw , test , model )
809
+ repr_path = self .reproduced .registry .get_result_key (tw , test , model )
806
810
807
811
results [test .name ][tw ][model ] = {
808
812
"hash" : (self .get_hash (orig_path ) == self .get_hash (repr_path )),
@@ -811,8 +815,8 @@ def get_filecomp(self):
811
815
else :
812
816
results [test .name ] = dict .fromkeys (models_orig )
813
817
for model in models_orig :
814
- orig_path = self .original .registry .get_result (win_orig [- 1 ], test , model )
815
- repr_path = self .reproduced .registry .get_result (win_orig [- 1 ], test , model )
818
+ orig_path = self .original .registry .get_result_key (win_orig [- 1 ], test , model )
819
+ repr_path = self .reproduced .registry .get_result_key (win_orig [- 1 ], test , model )
816
820
results [test .name ][model ] = {
817
821
"hash" : (self .get_hash (orig_path ) == self .get_hash (repr_path )),
818
822
"byte2byte" : filecmp .cmp (orig_path , repr_path ),
0 commit comments