From 7ed8530cddde29cbbf8a712ccf28f093b0fc4ca1 Mon Sep 17 00:00:00 2001 From: Johannes Haux Date: Thu, 16 Jan 2020 17:11:37 +0100 Subject: [PATCH 1/3] :tada: Merges configs using new merge functionality `edflow.util.merge` allows for slimmer config definitions. --- edflow/edflow | 31 +++++++++++++++++++------------ edflow/util.py | 47 +++++++++++++++++++++++++++++++++++++++++++++-- 2 files changed, 64 insertions(+), 14 deletions(-) diff --git a/edflow/edflow b/edflow/edflow index 413b0b9..00377bb 100644 --- a/edflow/edflow +++ b/edflow/edflow @@ -1,8 +1,10 @@ #!/usr/bin/env python3 import os + # directly terminate on broken pipe without throwing exception import signal + signal.signal(signal.SIGPIPE, signal.SIG_DFL) import sys # noqa @@ -16,16 +18,23 @@ from edflow.main import train, test # noqa from edflow.custom_logging import run, get_logger # noqa from edflow.hooks.checkpoint_hooks.common import get_latest_checkpoint # noqa from edflow.config import parse_unknown_args, update_config -from edflow.util import retrieve +from edflow.util import retrieve, merge def load_config(base_configs, additional_kwargs): - config = dict() + """ + Loads configs using `yaml`, merges them and adds additional ``kwargs``. + """ + + configs = [] if base_configs: for base in base_configs: with open(base) as f: - config.update(yaml.full_load(f)) + configs += [yaml.full_load(f)] + + config = merge(*configs) update_config(config, additional_kwargs) + return config @@ -35,9 +44,8 @@ def main(opt, additional_kwargs): config["model"] = "edflow.util.NoModel" run_kwargs = { - "git": retrieve(config, "integrations/git", - default=False), - "log_level": opt.log_level + "git": retrieve(config, "integrations/git", default=False), + "log_level": opt.log_level, } # Project manager: use existing project or set up new project if opt.project is not None: @@ -62,8 +70,7 @@ def main(opt, additional_kwargs): name = config.get("experiment_name", None) if opt.name is not None: name = opt.name - run.init(log_dir="logs", code_root=code_root, postfix=name, - **run_kwargs) + run.init(log_dir="logs", code_root=code_root, postfix=name, **run_kwargs) # Logger logger = get_logger("main") @@ -106,9 +113,7 @@ def main(opt, additional_kwargs): ) ) - logger.info( - "Evaluation config:\n{}".format(yaml.dump(config)) - ) + logger.info("Evaluation config:\n{}".format(yaml.dump(config))) test(config, run.latest_eval, checkpoint, debug=opt.debug) @@ -143,7 +148,9 @@ if __name__ == "__main__": default="info", help="set the logging level.", ) - parser.add_argument("-d", "--debug", action="store_true", help="enable post-mortem debugging") + parser.add_argument( + "-d", "--debug", action="store_true", help="enable post-mortem debugging" + ) opt, unknown = parser.parse_known_args() additional_kwargs = parse_unknown_args(unknown) diff --git a/edflow/util.py b/edflow/util.py index 1003a17..97ea899 100644 --- a/edflow/util.py +++ b/edflow/util.py @@ -7,6 +7,7 @@ from fastnumbers import fast_int from typing import * import importlib +from copy import deepcopy try: from IPython import get_ipython @@ -448,14 +449,14 @@ def set_value(list_or_dict, key, val, splitval="/"): Parameters ---------- + list_or_dict : list or dict + Possibly nested list or dictionary. key : str ``key/to/value``, path like string describing all keys necessary to consider to get to the desired value. List indices can also be passed here. value : object Anything you want to put behind :attr:`key` - list_or_dict : list or dict - Possibly nested list or dictionary. splitval : str String that defines the delimiter between keys of the different depth levels in :attr:`key`. @@ -592,12 +593,54 @@ def contains_key(nested_thing, key, splitval="/", expand=True): def update(to_update, to_update_with, splitval="/", expand=True): + """ + Updates the nested object :attr:`to_update` using the entries in :attr:`to_update_with`. + + Arguments + --------- + + to_update : list or dict + Possibly nested list or dictionary. Values in this object will be + overwritten. + to_update_with : list or dict + Possibly nested list or dictionary. + splitval : str + String that defines the delimiter between keys of the different depth + levels in :attr:`key`. + expand : bool + Whether to expand callable nodes on the path or not. + """ + def _update(key, value): set_value(to_update, key, value, splitval=splitval) walk(to_update_with, _update, splitval=splitval, pass_key=True) +def merge(*object_list, **update_kwargs): + """ + Merge nested objects using edflow's own :func:`edflow.util.update` function. + + Arguments + --------- + object_list : positional arguments + A list of nested dicts or lists. + update_kwargs : keyword arguments + Keyword arguments passed to :func:`edflow.util.update`. + + Returns + ------- + merged : nested dict or list + A nested config object + """ + + merged = deepcopy(object_list[0]) + for c in object_list[1:]: + update(merged, c, **update_kwargs) + + return merged + + def get_leaf_names(nested_thing): class LeafGetter: def __call__(self, key, value): From 2a069eb4bf974efc2de673b843ccf9823572f15f Mon Sep 17 00:00:00 2001 From: Johannes Haux Date: Thu, 16 Jan 2020 17:12:52 +0100 Subject: [PATCH 2/3] :white_check_mark: Adds tests for update and merge --- tests/test_util.py | 43 ++++++++++++++++++++++++++++++++++++++++++- 1 file changed, 42 insertions(+), 1 deletion(-) diff --git a/tests/test_util.py b/tests/test_util.py index 41d30a6..81f8f6e 100644 --- a/tests/test_util.py +++ b/tests/test_util.py @@ -6,6 +6,8 @@ walk, set_default, contains_key, + update, + merge, KeyNotFoundError, get_leaf_names, ) @@ -753,7 +755,46 @@ def test_set_default_key_not_contained(): assert val == "new" -# =================== set_default ================ +# =================== update ================ + + +def test_update(): + dol = {"a": [1, 2], "b": {"c": {"d": 1}}, "e": 2} + ref = {"a": [1, 2, None, 4], "b": {"c": {"d": 1}}, "e": 2} + + u = {"a/3": 4} + + update(dol, u) + + assert dol == ref + + +def test_update_fancy_inject(): + dol = [1, 2, 3] + ref = [{"a": 1}, 2, 3] + + u = {"0/a": 1} + + update(dol, u) + + assert dol == ref + + +# =================== merge ================ + + +def test_merge(): + dol1 = {"b": {"c": {"d": 1}}} + dol2 = {"b": {"c": {"e": 2}}} + + ref = {"b": {"c": {"d": 1, "e": 2}}} + + value = merge(dol1, dol2) + + assert value == ref + + +# =================== contains_key ================ def test_contains_key(): From b31b7aa8bde485b15b5d284581348017fae09316 Mon Sep 17 00:00:00 2001 From: Johannes Haux Date: Thu, 16 Jan 2020 17:13:05 +0100 Subject: [PATCH 3/3] Changelog --- CHANGELOG.md | 1 + 1 file changed, 1 insertion(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 7b7688c..2a9e272 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -34,6 +34,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - CHANGELOG.md to document notable changes. ### Changed +- Configs are now merged using `edflow.util.merge`, which allows for slimmer config definitions. - Changed configuration of integrations: `EDFLOWGIT` now `integrations/git`, `wandb_logging` now `integrations/wandb`, `tensorboardX_logging` now `--integrations/tensorboardX`. - ProjectManager is now `edflow.run` and initialized with `edflow.run.init(...)`. - Saved config files use `-` instead of `:` in filename to be consistent.