Commit 2e8c9eb

Merge pull request #57 from cseptesting/53-correct-management-of-models-input-catalog-and-arg-files
Update management of input files
2 parents e85b038 + 9b674ef commit 2e8c9eb

29 files changed: 1109 additions, 653 deletions

floatcsep/__init__.py

Lines changed: 1 addition & 1 deletion
@@ -2,7 +2,7 @@
 from floatcsep import experiment
 from floatcsep import model
 from floatcsep.infrastructure import engine, environments, registries, repositories, logger
-from floatcsep.utils import readers, accessors, helpers
+from floatcsep.utils import file_io, accessors, helpers
 from floatcsep.postprocess import reporting, plot_handler

 from importlib.metadata import version, PackageNotFoundError

floatcsep/evaluation.py

Lines changed: 3 additions & 3 deletions
@@ -287,7 +287,7 @@ def plot_results(
 # Regular consistency/comparative test plots (e.g., many models)
 try:
 for time_str in timewindow:
-fig_path = registry.get_figure(time_str, self.name)
+fig_path = registry.get_figure_key(time_str, self.name)
 results = self.read_results(time_str, models)
 ax = func(results, plot_args=fargs, **fkwargs)
 if "code" in fargs:
@@ -307,7 +307,7 @@ def plot_results(
 registry.figures[time_str][fig_name] = os.path.join(
 time_str, "figures", fig_name
 )
-fig_path = registry.get_figure(time_str, fig_name)
+fig_path = registry.get_figure_key(time_str, fig_name)
 ax = func(result, plot_args=fargs, **fkwargs, show=False)
 if "code" in fargs:
 exec(fargs["code"])
@@ -318,7 +318,7 @@ def plot_results(
 pyplot.show()

 elif self.type in ["sequential", "sequential_comparative", "batch"]:
-fig_path = registry.get_figure(timewindow[-1], self.name)
+fig_path = registry.get_figure_key(timewindow[-1], self.name)
 results = self.read_results(timewindow[-1], models)
 ax = func(results, plot_args=fargs, **fkwargs)

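The only change in this file is the rename of the registry's figure lookup from get_figure to get_figure_key. A minimal sketch of a call site after this commit, reusing the variables shown in the hunks above:

    # formerly: fig_path = registry.get_figure(time_str, self.name)
    fig_path = registry.get_figure_key(time_str, self.name)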

floatcsep/experiment.py

Lines changed: 27 additions & 23 deletions
@@ -25,6 +25,7 @@
 parse_nested_dicts,
 )
 from floatcsep.infrastructure.engine import Task, TaskGraph
+from floatcsep.infrastructure.logger import log_models_tree, log_results_tree

 log = logging.getLogger("floatLogger")

@@ -52,8 +53,8 @@ class Experiment:
 - growth (:class:`str`): `incremental` or `cumulative`
 - offset (:class:`float`): recurrence of forecast creation.

-For further details, see :func:`~floatcsep.utils.timewindows_ti`
-and :func:`~floatcsep.utils.timewindows_td`
+For further details, see :func:`~floatcsep.utils.time_windows_ti`
+and :func:`~floatcsep.utils.time_windows_td`

 region_config (dict): Contains all the spatial and magnitude
 specifications. It must contain the following keys:
@@ -75,6 +76,7 @@ class Experiment:

 model_config (str): Path to the models' configuration file
 test_config (str): Path to the evaluations' configuration file
+run_mode (str): 'sequential' or 'parallel'
 default_test_kwargs (dict): Default values for the testing
 (seed, number of simulations, etc.)
 postprocess (dict): Contains the instruction for postprocessing
@@ -99,6 +101,7 @@ def __init__(
 postprocess: str = None,
 default_test_kwargs: dict = None,
 rundir: str = "results",
+run_mode: str = "sequential",
 report_hook: dict = None,
 **kwargs,
 ) -> None:
@@ -118,14 +121,15 @@ def __init__(
 os.makedirs(os.path.join(workdir, rundir), exist_ok=True)

 self.name = name if name else "floatingExp"
-self.registry = ExperimentRegistry(workdir, rundir)
+self.registry = ExperimentRegistry.factory(workdir=workdir, run_dir=rundir)
 self.results_repo = ResultsRepository(self.registry)
 self.catalog_repo = CatalogRepository(self.registry)

 self.config_file = kwargs.get("config_file", None)
 self.original_config = kwargs.get("original_config", None)
 self.original_run_dir = kwargs.get("original_rundir", None)
 self.run_dir = rundir
+self.run_mode = run_mode
 self.seed = kwargs.get("seed", None)
 self.time_config = read_time_cfg(time_config, **kwargs)
 self.region_config = read_region_cfg(region_config, **kwargs)
@@ -143,7 +147,7 @@ def __init__(
 log.info(f"Setting up experiment {self.name}:")
 log.info(f"\tStart: {self.start_date}")
 log.info(f"\tEnd: {self.end_date}")
-log.info(f"\tTime windows: {len(self.timewindows)}")
+log.info(f"\tTime windows: {len(self.time_windows)}")
 log.info(f"\tRegion: {self.region.name if self.region else None}")
 log.info(
 f"\tMagnitude range: [{numpy.min(self.magnitudes)},"
@@ -175,7 +179,7 @@ def __getattr__(self, item: str) -> object:
 Override built-in method to return the experiment attributes by also using the command
 ``experiment.{attr}``. Adds also to the experiment scope the keys of
 :attr:`region_config` or :attr:`time_config`. These are: ``start_date``, ``end_date``,
-``timewindows``, ``horizon``, ``offset``, ``region``, ``magnitudes``, ``mag_min``,
+``time_windows``, ``horizon``, ``offset``, ``region``, ``magnitudes``, ``mag_min``,
 `mag_max``, ``mag_bin``, ``depth_min`` depth_max .
 """

@@ -295,8 +299,8 @@ def stage_models(self) -> None:
 """
 log.info("Staging models")
 for i in self.models:
-i.stage(self.timewindows)
-self.registry.add_forecast_registry(i)
+i.stage(self.time_windows, run_mode=self.run_mode, run_dir=self.run_dir)
+self.registry.add_model_registry(i)

 def set_tests(self, test_config: Union[str, Dict, List]) -> list:
 """
@@ -376,17 +380,17 @@ def set_tasks(self) -> None:
 """

 # Set the file path structure
-self.registry.build_tree(self.timewindows, self.models, self.tests)
+self.registry.build_tree(self.time_windows, self.models, self.tests, self.run_mode)

 log.debug("Pre-run forecast summary")
-self.registry.log_forecast_trees(self.timewindows)
+log_models_tree(log, self.registry, self.time_windows)
 log.debug("Pre-run result summary")
-self.registry.log_results_tree()
+log_results_tree(log, self.registry)

 log.info("Setting up experiment's tasks")

 # Get the time windows strings
-tw_strings = timewindow2str(self.timewindows)
+tw_strings = timewindow2str(self.time_windows)

 # Prepare the testing catalogs
 task_graph = TaskGraph()
@@ -481,7 +485,7 @@ def set_tasks(self) -> None:
 )
 # Set up the Sequential_Comparative Scores
 elif test_k.type == "sequential_comparative":
-tw_strs = timewindow2str(self.timewindows)
+tw_strs = timewindow2str(self.time_windows)
 for model_j in self.models:
 task_k = Task(
 instance=test_k,
@@ -504,7 +508,7 @@ def set_tasks(self) -> None:
 )
 # Set up the Batch comparative Scores
 elif test_k.type == "batch":
-time_str = timewindow2str(self.timewindows[-1])
+time_str = timewindow2str(self.time_windows[-1])
 for model_j in self.models:
 task_k = Task(
 instance=test_k,
@@ -540,9 +544,9 @@ def run(self) -> None:
 self.task_graph.run()
 log.info("Calculation completed")
 log.debug("Post-run forecast registry")
-self.registry.log_forecast_trees(self.timewindows)
+log_models_tree(log, self.registry, self.time_windows)
 log.debug("Post-run result summary")
-self.registry.log_results_tree()
+log_results_tree(log, self.registry)

 def read_results(self, test: Evaluation, window: str) -> List:
 """
@@ -559,7 +563,7 @@ def make_repr(self) -> None:

 """
 log.info("Creating reproducibility config file")
-repr_config = self.registry.get("repr_config")
+repr_config = self.registry.get_attr("repr_config")

 # Dropping region to results folder if it is a file
 region_path = self.region_config.get("path", False)
@@ -604,7 +608,7 @@ def as_dict(self, extra: Sequence = (), extended=False) -> dict:
 "time_config": {
 i: j
 for i, j in self.time_config.items()
-if (i not in ("timewindows",) or extended)
+if (i not in ("time_windows",) or extended)
 },
 "region_config": {
 i: j
@@ -731,7 +735,7 @@ def test_stat(test_orig, test_repr):

 def get_results(self):

-win_orig = timewindow2str(self.original.timewindows)
+win_orig = timewindow2str(self.original.time_windows)

 tests_orig = self.original.tests

@@ -787,7 +791,7 @@ def get_hash(filename):

 def get_filecomp(self):

-win_orig = timewindow2str(self.original.timewindows)
+win_orig = timewindow2str(self.original.time_windows)

 tests_orig = self.original.tests

@@ -801,8 +805,8 @@ def get_filecomp(self):
 for tw in win_orig:
 results[test.name][tw] = dict.fromkeys(models_orig)
 for model in models_orig:
-orig_path = self.original.registry.get_result(tw, test, model)
-repr_path = self.reproduced.registry.get_result(tw, test, model)
+orig_path = self.original.registry.get_result_key(tw, test, model)
+repr_path = self.reproduced.registry.get_result_key(tw, test, model)

 results[test.name][tw][model] = {
 "hash": (self.get_hash(orig_path) == self.get_hash(repr_path)),
@@ -811,8 +815,8 @@ def get_filecomp(self):
 else:
 results[test.name] = dict.fromkeys(models_orig)
 for model in models_orig:
-orig_path = self.original.registry.get_result(win_orig[-1], test, model)
-repr_path = self.reproduced.registry.get_result(win_orig[-1], test, model)
+orig_path = self.original.registry.get_result_key(win_orig[-1], test, model)
+repr_path = self.reproduced.registry.get_result_key(win_orig[-1], test, model)
 results[test.name][model] = {
 "hash": (self.get_hash(orig_path) == self.get_hash(repr_path)),
 "byte2byte": filecmp.cmp(orig_path, repr_path),

floatcsep/infrastructure/logger.py

Lines changed: 71 additions & 0 deletions
@@ -60,3 +60,74 @@ def set_console_log_level(log_level):
     for handler in logger.handlers:
         if isinstance(handler, logging.StreamHandler):
             handler.setLevel(log_level)
+
+
+
+
+def log_models_tree(log, experiment_registry, time_windows):
+    """
+    Logs the forecasts for all models managed by this ExperimentFileRegistry.
+    """
+    log.debug("===================")
+    log.debug(f" Total Time Windows: {len(time_windows)}")
+    for model_name, registry in experiment_registry.model_registries.items():
+        log.debug(f" Model: {model_name}")
+        exists_group = []
+        not_exist_group = []
+
+        for timewindow, filepath in registry.forecasts.items():
+            if registry.forecast_exists(timewindow):
+                exists_group.append(timewindow)
+            else:
+                not_exist_group.append(timewindow)
+
+        log.debug(f" Existing forecasts: {len(exists_group)}")
+        log.debug(f" Missing forecasts: {len(not_exist_group)}")
+        for timewindow in not_exist_group:
+            log.debug(f" Time Window: {timewindow}")
+    log.debug("===================")
+
+
+def log_results_tree(log, experiment_registry):
+    """
+    Logs a summary of the results dictionary, sorted by test.
+    For each test and time window, it logs whether all models have results,
+    or if some results are missing, and specifies which models are missing.
+    """
+    log.debug("===================")
+
+    total_results = results_exist_count = results_not_exist_count = 0
+
+    # Get all unique test names and sort them
+    all_tests = sorted(
+        {test_name for tests in experiment_registry.results.values() for test_name in tests}
+    )
+
+    for test_name in all_tests:
+        log.debug(f"Test: {test_name}")
+        for timewindow, tests in experiment_registry.results.items():
+            if test_name in tests:
+                models = tests[test_name]
+                missing_models = []
+
+                for model_name, result_path in models.items():
+                    total_results += 1
+                    result_full_path = experiment_registry.get_result_key(timewindow, test_name, model_name)
+                    if os.path.exists(result_full_path):
+                        results_exist_count += 1
+                    else:
+                        results_not_exist_count += 1
+                        missing_models.append(model_name)
+
+                if not missing_models:
+                    log.debug(f" Time Window: {timewindow} - All models evaluated.")
+                else:
+                    log.debug(
+                        f" Time Window: {timewindow} - Missing results for models: "
+                        f"{', '.join(missing_models)}"
+                    )
+
+    log.debug(f"Total Results: {total_results}")
+    log.debug(f"Results that Exist: {results_exist_count}")
+    log.debug(f"Results that Do Not Exist: {results_not_exist_count}")
+    log.debug("===================")
