Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- CHANGELOG.md to document notable changes.

### Changed
- Configs are now merged using `edflow.util.merge`, which allows for slimmer config definitions.
- Changed usage from tensorboardX to tensorboard, due to native intergration in pytorch.
- EvalPipeline defaults to keypath/labels for finding labels.
- A `datasets` dict is now preferred over `dataset` and `validation_dataset` (backwards compatible default: `dataset` -> `datasets/train` and `validation_dataset` -> `datasets/validation`).
Expand Down
31 changes: 19 additions & 12 deletions edflow/edflow
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
#!/usr/bin/env python3

import os

# directly terminate on broken pipe without throwing exception
import signal

signal.signal(signal.SIGPIPE, signal.SIG_DFL)

import sys # noqa
Expand All @@ -17,16 +19,23 @@ from edflow.main import train, test # noqa
from edflow.custom_logging import run, get_logger # noqa
from edflow.hooks.checkpoint_hooks.common import get_latest_checkpoint # noqa
from edflow.config import parse_unknown_args, update_config
from edflow.util import retrieve
from edflow.util import retrieve, merge


def load_config(base_configs, additional_kwargs):
config = dict()
"""
Loads configs using `yaml`, merges them and adds additional ``kwargs``.
"""

configs = []
if base_configs:
for base in base_configs:
with open(base) as f:
config.update(yaml.full_load(f))
configs += [yaml.full_load(f)]

config = merge(*configs)
update_config(config, additional_kwargs)

return config


Expand All @@ -46,9 +55,8 @@ def main(opt, additional_kwargs):
config["model"] = "edflow.util.NoModel"

run_kwargs = {
"git": retrieve(config, "integrations/git",
default=False),
"log_level": opt.log_level
"git": retrieve(config, "integrations/git", default=False),
"log_level": opt.log_level,
}
# Project manager: use existing project or set up new project
if opt.project is not None:
Expand All @@ -73,8 +81,7 @@ def main(opt, additional_kwargs):
name = config.get("experiment_name", None)
if opt.name is not None:
name = opt.name
run.init(log_dir="logs", code_root=code_root, postfix=name,
**run_kwargs)
run.init(log_dir="logs", code_root=code_root, postfix=name, **run_kwargs)

# Logger
logger = get_logger("main")
Expand Down Expand Up @@ -117,9 +124,7 @@ def main(opt, additional_kwargs):
)
)

logger.info(
"Evaluation config:\n{}".format(yaml.dump(config))
)
logger.info("Evaluation config:\n{}".format(yaml.dump(config)))
test(config, run.latest_eval, checkpoint, debug=opt.debug)


Expand Down Expand Up @@ -154,7 +159,9 @@ if __name__ == "__main__":
default="info",
help="set the logging level.",
)
parser.add_argument("-d", "--debug", action="store_true", help="enable post-mortem debugging")
parser.add_argument(
"-d", "--debug", action="store_true", help="enable post-mortem debugging"
)

opt, unknown = parser.parse_known_args()
additional_kwargs = parse_unknown_args(unknown)
Expand Down
47 changes: 45 additions & 2 deletions edflow/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import pickle
from typing import *
import importlib
from copy import deepcopy

try:
from IPython import get_ipython
Expand Down Expand Up @@ -447,14 +448,14 @@ def set_value(list_or_dict, key, val, splitval="/"):
Parameters
----------

list_or_dict : list or dict
Possibly nested list or dictionary.
key : str
``key/to/value``, path like string describing all keys necessary to
consider to get to the desired value. List indices can also be passed
here.
value : object
Anything you want to put behind :attr:`key`
list_or_dict : list or dict
Possibly nested list or dictionary.
splitval : str
String that defines the delimiter between keys of the different depth
levels in :attr:`key`.
Expand Down Expand Up @@ -597,12 +598,54 @@ def contains_key(nested_thing, key, splitval="/", expand=True):


def update(to_update, to_update_with, splitval="/", expand=True):
"""
Updates the nested object :attr:`to_update` using the entries in :attr:`to_update_with`.

Arguments
---------

to_update : list or dict
Possibly nested list or dictionary. Values in this object will be
overwritten.
to_update_with : list or dict
Possibly nested list or dictionary.
splitval : str
String that defines the delimiter between keys of the different depth
levels in :attr:`key`.
expand : bool
Whether to expand callable nodes on the path or not.
"""

def _update(key, value):
set_value(to_update, key, value, splitval=splitval)

walk(to_update_with, _update, splitval=splitval, pass_key=True)


def merge(*object_list, **update_kwargs):
"""
Merge nested objects using edflow's own :func:`edflow.util.update` function.

Arguments
---------
object_list : positional arguments
A list of nested dicts or lists.
update_kwargs : keyword arguments
Keyword arguments passed to :func:`edflow.util.update`.

Returns
-------
merged : nested dict or list
A nested config object
"""

merged = deepcopy(object_list[0])
for c in object_list[1:]:
update(merged, c, **update_kwargs)

return merged


def get_leaf_names(nested_thing):
class LeafGetter:
def __call__(self, key, value):
Expand Down
43 changes: 42 additions & 1 deletion tests/test_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@
walk,
set_default,
contains_key,
update,
merge,
KeyNotFoundError,
get_leaf_names,
)
Expand Down Expand Up @@ -753,7 +755,46 @@ def test_set_default_key_not_contained():
assert val == "new"


# =================== set_default ================
# =================== update ================


def test_update():
dol = {"a": [1, 2], "b": {"c": {"d": 1}}, "e": 2}
ref = {"a": [1, 2, None, 4], "b": {"c": {"d": 1}}, "e": 2}

u = {"a/3": 4}

update(dol, u)

assert dol == ref


def test_update_fancy_inject():
dol = [1, 2, 3]
ref = [{"a": 1}, 2, 3]

u = {"0/a": 1}

update(dol, u)

assert dol == ref


# =================== merge ================


def test_merge():
dol1 = {"b": {"c": {"d": 1}}}
dol2 = {"b": {"c": {"e": 2}}}

ref = {"b": {"c": {"d": 1, "e": 2}}}

value = merge(dol1, dol2)

assert value == ref


# =================== contains_key ================


def test_contains_key():
Expand Down