diff --git a/CHANGELOG.md b/CHANGELOG.md
index e992115..cf3d161 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -41,6 +41,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 - CHANGELOG.md to document notable changes.
 
 ### Changed
+- Configs are now merged using `edflow.util.merge`, which allows for slimmer config definitions.
 - Changed usage from tensorboardX to tensorboard, due to native intergration in pytorch.
 - EvalPipeline defaults to keypath/labels for finding labels.
 - A `datasets` dict is now preferred over `dataset` and `validation_dataset` (backwards compatible default: `dataset` -> `datasets/train` and `validation_dataset` -> `datasets/validation`).
diff --git a/edflow/edflow b/edflow/edflow
index a763005..68501d8 100644
--- a/edflow/edflow
+++ b/edflow/edflow
@@ -1,8 +1,10 @@
 #!/usr/bin/env python3
 
 import os
+
 # directly terminate on broken pipe without throwing exception
 import signal
+
 signal.signal(signal.SIGPIPE, signal.SIG_DFL)
 
 import sys  # noqa
@@ -17,16 +19,23 @@ from edflow.main import train, test # noqa
 from edflow.custom_logging import run, get_logger  # noqa
 from edflow.hooks.checkpoint_hooks.common import get_latest_checkpoint  # noqa
 from edflow.config import parse_unknown_args, update_config
-from edflow.util import retrieve
+from edflow.util import retrieve, merge
 
 
 def load_config(base_configs, additional_kwargs):
-    config = dict()
+    """
+    Loads configs using `yaml`, merges them and adds additional ``kwargs``.
+    """
+
+    configs = []
     if base_configs:
         for base in base_configs:
             with open(base) as f:
-                config.update(yaml.full_load(f))
+                configs.append(yaml.full_load(f))
+
+    config = merge(*configs)
     update_config(config, additional_kwargs)
+
     return config
 
 
@@ -46,9 +55,8 @@ def main(opt, additional_kwargs):
         config["model"] = "edflow.util.NoModel"
 
     run_kwargs = {
-        "git": retrieve(config, "integrations/git",
-                        default=False),
-        "log_level": opt.log_level
+        "git": retrieve(config, "integrations/git", default=False),
+        "log_level": opt.log_level,
     }
     # Project manager: use existing project or set up new project
     if opt.project is not None:
@@ -73,8 +81,7 @@ def main(opt, additional_kwargs):
         name = config.get("experiment_name", None)
         if opt.name is not None:
             name = opt.name
-        run.init(log_dir="logs", code_root=code_root, postfix=name,
-                 **run_kwargs)
+        run.init(log_dir="logs", code_root=code_root, postfix=name, **run_kwargs)
 
     # Logger
     logger = get_logger("main")
@@ -117,9 +124,7 @@ def main(opt, additional_kwargs):
             )
         )
 
-        logger.info(
-            "Evaluation config:\n{}".format(yaml.dump(config))
-        )
+        logger.info("Evaluation config:\n{}".format(yaml.dump(config)))
 
         test(config, run.latest_eval, checkpoint, debug=opt.debug)
 
@@ -154,7 +159,9 @@ if __name__ == "__main__":
         default="info",
         help="set the logging level.",
     )
-    parser.add_argument("-d", "--debug", action="store_true", help="enable post-mortem debugging")
+    parser.add_argument(
+        "-d", "--debug", action="store_true", help="enable post-mortem debugging"
+    )
 
     opt, unknown = parser.parse_known_args()
     additional_kwargs = parse_unknown_args(unknown)
diff --git a/edflow/util.py b/edflow/util.py
index f281fb4..4ae2d31 100644
--- a/edflow/util.py
+++ b/edflow/util.py
@@ -6,6 +6,7 @@ import pickle
 
 from typing import *
 import importlib
+from copy import deepcopy
 
 try:
     from IPython import get_ipython
@@ -447,14 +448,14 @@ def set_value(list_or_dict, key, val, splitval="/"):
 
     Parameters
     ----------
+    list_or_dict : list or dict
+        Possibly nested list or dictionary.
     key : str
         ``key/to/value``, path like string describing all keys necessary to
         consider to get to the desired value. List indices can also be passed
         here.
     value : object
         Anything you want to put behind :attr:`key`
-    list_or_dict : list or dict
-        Possibly nested list or dictionary.
     splitval : str
         String that defines the delimiter between keys of the different depth
         levels in :attr:`key`.
@@ -597,12 +598,61 @@ def contains_key(nested_thing, key, splitval="/", expand=True):
 
 
 def update(to_update, to_update_with, splitval="/", expand=True):
+    """
+    Updates :attr:`to_update` in place with the entries of :attr:`to_update_with`.
+
+    Parameters
+    ----------
+    to_update : list or dict
+        Possibly nested list or dictionary. Values in this object will be
+        overwritten.
+    to_update_with : list or dict
+        Possibly nested list or dictionary.
+    splitval : str
+        String that defines the delimiter between keys of the different depth
+        levels in a key path, e.g. ``key/to/value``.
+    expand : bool
+        Whether to expand callable nodes on the path or not. Currently unused
+        by this function.
+    """
+
     def _update(key, value):
         set_value(to_update, key, value, splitval=splitval)
 
     walk(to_update_with, _update, splitval=splitval, pass_key=True)
 
 
+def merge(*object_list, **update_kwargs):
+    """
+    Merge nested objects using edflow's own :func:`edflow.util.update` function.
+
+    The first object is deep-copied and then updated with every following
+    object in order, so the first object itself is left unmodified.
+
+    Parameters
+    ----------
+    object_list : positional arguments
+        A list of nested dicts or lists.
+    update_kwargs : keyword arguments
+        Keyword arguments passed to :func:`edflow.util.update`.
+
+    Returns
+    -------
+    merged : nested dict or list
+        A nested config object. An empty ``dict`` if no objects are given.
+    """
+
+    if not object_list:
+        # Nothing to merge, e.g. no base config was given: avoid an IndexError.
+        return {}
+
+    merged = deepcopy(object_list[0])
+    for c in object_list[1:]:
+        update(merged, c, **update_kwargs)
+
+    return merged
+
+
 def get_leaf_names(nested_thing):
     class LeafGetter:
         def __call__(self, key, value):
diff --git a/tests/test_util.py b/tests/test_util.py
index 41d30a6..81f8f6e 100644
--- a/tests/test_util.py
+++ b/tests/test_util.py
@@ -6,6 +6,8 @@
     walk,
     set_default,
     contains_key,
+    update,
+    merge,
     KeyNotFoundError,
     get_leaf_names,
 )
@@ -753,7 +755,50 @@ def test_set_default_key_not_contained():
     assert val == "new"
 
 
-# =================== set_default ================
+# =================== update ================
+
+
+def test_update():
+    dol = {"a": [1, 2], "b": {"c": {"d": 1}}, "e": 2}
+    ref = {"a": [1, 2, None, 4], "b": {"c": {"d": 1}}, "e": 2}
+
+    u = {"a/3": 4}
+
+    update(dol, u)
+
+    assert dol == ref
+
+
+def test_update_fancy_inject():
+    dol = [1, 2, 3]
+    ref = [{"a": 1}, 2, 3]
+
+    u = {"0/a": 1}
+
+    update(dol, u)
+
+    assert dol == ref
+
+
+# =================== merge ================
+
+
+def test_merge():
+    dol1 = {"b": {"c": {"d": 1}}}
+    dol2 = {"b": {"c": {"e": 2}}}
+
+    ref = {"b": {"c": {"d": 1, "e": 2}}}
+
+    value = merge(dol1, dol2)
+
+    assert value == ref
+
+
+def test_merge_no_objects():
+    assert merge() == {}
+
+
+# =================== contains_key ================
 
 
 def test_contains_key():