diff --git a/fast_plotter/__main__.py b/fast_plotter/__main__.py index 54221e5..dc03c7a 100644 --- a/fast_plotter/__main__.py +++ b/fast_plotter/__main__.py @@ -140,12 +140,12 @@ def process_one_file(infile, args): def dress_main_plots(plots, annotations=[], yscale=None, ylabel=None, legend={}, limits={}, xtickrotation=None, **kwargs): for main_ax, summary_ax in plots.values(): - add_annotations(annotations, main_ax) + add_annotations(annotations, main_ax, summary_ax) if yscale: main_ax.set_yscale(yscale) if ylabel: main_ax.set_ylabel(ylabel) - main_ax.legend(**legend) + main_ax.legend(**legend).set_zorder(20) main_ax.grid(True) main_ax.set_axisbelow(True) for axis, lims in limits.items(): diff --git a/fast_plotter/plotting.py b/fast_plotter/plotting.py index 5594caa..d3c390b 100644 --- a/fast_plotter/plotting.py +++ b/fast_plotter/plotting.py @@ -6,6 +6,7 @@ import matplotlib.pyplot as plt import matplotlib.colors as mc import logging +import re logger = logging.getLogger(__name__) @@ -423,8 +424,41 @@ def _merge_datasets(df, style, dataset_col, param_name="_merge_datasets", err_fr return df -def add_annotations(annotations, ax): +def annotate_lines(cfg, main_ax, summary_ax): + linename = list(cfg.keys())[0] + annotDict = cfg[linename] + if 'values' not in annotDict.keys(): + raise(RuntimeError("Must provide values for line placement.")) + annotDefaults = {"style": "-", "alpha": 1, "width": 1.5, + "colour": 'k', "label": None, "vmin": 0, + "vmax": 1, "zorder": 10, "axes": ["main"]} + annotDict.update({key: value for key, value in annotDefaults.items() + if key not in annotDict.keys()}) + lineKeys = ['values', 'style', 'alpha', 'width', 'colour', 'label', 'vmin', 'vmax', 'zorder', 'axes'] + if set(annotDict.keys()).difference(set(lineKeys)): + logger.warn("Invalid parameter(s) given to line annotations. Options are {}".format(lineKeys)) + values, style, alpha, width, colour, label, vmin, vmax, zorder, axes = [annotDict[key] for key in lineKeys] + for axis in axes: + awidth = 0.6 * width if (axis == 'summary') else width + ax = main_ax if (str(axis) == 'main') else summary_ax if (str(axis) == 'summary') else None + if ax is None: + logger.warn("Axis must exist and either be 'main' or 'summary'. {} is None".format(axis)) + continue + for value in values: + value = float(value) + if 'hline' in linename: + ax.axhline(value, vmin, vmax, color=colour, label=label, + alpha=alpha, ls=style, lw=awidth, zorder=zorder) + if 'vline' in linename: + ax.axvline(value, vmin, vmax, color=colour, label=label, + alpha=alpha, ls=style, lw=awidth, zorder=zorder) + + +def add_annotations(annotations, ax, summary_ax=None): for cfg in annotations: + if list(filter(lambda key: re.match("(.*hline.*|.*vline.*)", key), cfg.keys())): + annotate_lines(cfg, ax, summary_ax) + continue cfg = cfg.copy() s = cfg.pop("text") xy = cfg.pop("position")