diff --git a/fast_plotter/plotting.py b/fast_plotter/plotting.py index 5594caa..d9b08e9 100644 --- a/fast_plotter/plotting.py +++ b/fast_plotter/plotting.py @@ -6,6 +6,7 @@ import matplotlib.pyplot as plt import matplotlib.colors as mc import logging +import re logger = logging.getLogger(__name__) @@ -26,7 +27,7 @@ def change_brightness(color, amount): def plot_all(df, project_1d=True, project_2d=True, data="data", signal=None, dataset_col="dataset", yscale="log", lumi=None, annotations=[], dataset_order=None, continue_errors=True, bin_variable_replacements={}, colourmap="nipy_spectral", - figsize=None, **kwargs): + figsize=None, other_dset_types={}, **kwargs): figures = {} dimensions = utils.binning_vars(df) @@ -53,7 +54,7 @@ def plot_all(df, project_1d=True, project_2d=True, data="data", signal=None, dat plot = plot_1d_many(projected, data=data, signal=signal, dataset_col=dataset_col, scale_sims=lumi, colourmap=colourmap, dataset_order=dataset_order, - figsize=figsize, **kwargs + figsize=figsize, other_dset_args=other_dset_types, **kwargs ) figures[(("project", dim), ("yscale", yscale))] = plot except Exception as e: @@ -107,7 +108,8 @@ def get_colour(self, index=None, name=None): class FillColl(object): def __init__(self, n_colors=10, ax=None, fill=True, line=True, dataset_colours=None, - colourmap="nipy_spectral", dataset_order=None, linewidth=0.5, expected_xs=None): + colourmap="nipy_spectral", dataset_order=None, linewidth=0.5, + expected_xs=None, other_dset_args={}): self.calls = -1 self.expected_xs = expected_xs self.colors = ColorDict(n_colors=n_colors, order=dataset_order, @@ -117,6 +119,8 @@ def __init__(self, n_colors=10, ax=None, fill=True, line=True, dataset_colours=N self.fill = fill self.line = line self.linewidth = linewidth + self.other_dset_args = other_dset_args + self.dataset_colours = dataset_colours def pre_call(self, column): ax = self.ax @@ -129,16 +133,26 @@ def pre_call(self, column): def __call__(self, col, **kwargs): ax, x, y, color = self.pre_call(col) - if self.fill: + if self.fill and not self.other_dset_args: draw(ax, "fill_between", x=x, ys=["y1"], y1=y, label=col.name, expected_xs=self.expected_xs, linewidth=0, color=color, **kwargs) if self.line: if self.fill: - label = None - color = "k" - width = self.linewidth - style = "-" + if self.other_dset_args: + style = self.other_dset_args['style'] + label = col.name if self.other_dset_args['add_label'] else None + color = self.other_dset_args['colour'] if self.other_dset_args['colour']\ + else self.dataset_colours[col.name] if col.name in self.dataset_colours.keys()\ + else color + self.color = color + self.other_dset_args['tmp_colour'] = color + width = self.linewidth + else: + style = "-" + label = None + color = "k" + width = self.linewidth else: color = None label = col.name @@ -162,7 +176,8 @@ def __call__(self, col, **kwargs): def actually_plot(df, x_axis, y, yerr, kind, label, ax, dataset_col="dataset", - dataset_colours=None, colourmap="nipy_spectral", dataset_order=None): + dataset_colours=None, colourmap="nipy_spectral", + dataset_order=None, other_cfg_args={}): expected_xs = df.index.unique(x_axis).values if kind == "scatter": draw(ax, "errorbar", x=df.reset_index()[x_axis], ys=["y", "yerr"], y=df[y], yerr=df[yerr], @@ -202,6 +217,24 @@ def actually_plot(df, x_axis, y, yerr, kind, label, ax, dataset_col="dataset", y_up = (summed[y] + summed[yerr]).values draw(ax, "fill_between", x, ys=["y1", "y2"], y2=y_down, y1=y_up, color="gray", alpha=0.7, expected_xs=expected_xs) + elif kind == "other_dset_types": + if 'regex' not in other_cfg_args: + raise RuntimeError("Must specify a regex for other plotting datatype to be applied to") + options = ["alpha", "style", "width", "add_label", "add_error", "regex"] + alpha, style, width, add_label, add_error, regex = [other_cfg_args[key] for key in options] + filler = FillColl(n_datasets, ax=ax, fill=True, colourmap=colourmap, dataset_colours=dataset_colours, + dataset_order=dataset_order, expected_xs=expected_xs, linewidth=width, + other_dset_args=other_cfg_args) + vals.apply(filler, axis=0, step="mid") + if add_error: + for dset in list(set(df.reset_index()[dataset_col])): + if not re.compile(regex).match(dset): + continue + color = filler.color + dset_df = df.reset_index().loc[df.reset_index()[dataset_col] == dset].reset_index() + x = dset_df[x_axis] + draw(ax, "fill_between", x, ys=["y1", "y2"], y1=dset_df.eval("sumw+sqrt(sumw2)"), + y2=dset_df.eval("sumw-sqrt(sumw2)"), color=color, alpha=alpha, expected_xs=expected_xs) else: raise RuntimeError("Unknown value for 'kind', '{}'".format(kind)) @@ -330,7 +363,7 @@ def plot_1d_many(df, prefix="", data="data", signal=None, dataset_col="dataset", kind_data="scatter", kind_sims="fill-error-last", kind_signal="line", scale_sims=None, summary="ratio-error-both", colourmap="nipy_spectral", dataset_order=None, figsize=(5, 6), show_over_underflow=False, - dataset_colours=None, err_from_sumw2=False, data_legend="Data", **kwargs): + dataset_colours=None, err_from_sumw2=False, data_legend="Data", other_dset_args={}, **kwargs): y = "sumw" yvar = "sumw2" yerr = "err" @@ -352,13 +385,39 @@ def plot_1d_many(df, prefix="", data="data", signal=None, dataset_col="dataset", else: in_df_signal = None + config_extend = [] + if other_dset_args: + for dset_type in other_dset_args.keys(): + dset_type_labels = other_dset_args[dset_type]['regex'] + other_defaults = {"style": "-", "alpha": 0.2, "width": 1, + "colour": [], "dset_type": dset_type, "add_label": True, + "add_error": True, "plot_ratio": False} + default_specs = {key: val for key, val + in other_defaults.items() + if key not in other_dset_args[dset_type].keys()} + other_dset_args[dset_type].update(default_specs) + in_df_other, in_df_sims = utils.split_data_sims( + in_df_sims, data_labels=dset_type_labels, dataset_level=dataset_col) + config_extend.append((in_df_other, None, "other_dset_types", + dset_type_labels, "plot_other_dset", other_dset_args[dset_type])) + else: + in_df_other = None + + def_cfg_args = {"dset_type": ""} + config = [(in_df_sims, plot_sims, kind_sims, "Monte Carlo", "plot_sims", def_cfg_args), + (in_df_data, plot_data, kind_data, data_legend, "plot_data", def_cfg_args), + (in_df_signal, plot_signal, kind_signal, "Signal", "plot_signal", def_cfg_args), + ] + + config.extend(config_extend) + if in_df_data is None or in_df_sims is None: summary = None if not summary: - fig, main_ax = plt.subplots(1, 1, figsize=figsize) + fig, main_ax = plt.subplots(1, 1, figsize=[float(i) for i in figsize]) else: fig, ax = plt.subplots( - 2, 1, gridspec_kw={"height_ratios": (3, 1)}, sharex=True, figsize=figsize) + 2, 1, gridspec_kw={"height_ratios": (3, 1)}, sharex=True, figsize=[float(i) for i in figsize]) fig.subplots_adjust(hspace=.1) main_ax, summary_ax = ax @@ -370,18 +429,14 @@ def plot_1d_many(df, prefix="", data="data", signal=None, dataset_col="dataset", "Too few dimensions to multiple 1D graphs, use plot_1d instead") x_axis = x_axis[0] - config = [(in_df_sims, plot_sims, kind_sims, "Monte Carlo", "plot_sims"), - (in_df_data, plot_data, kind_data, data_legend, "plot_data"), - (in_df_signal, plot_signal, kind_signal, "Signal", "plot_signal"), - ] - for df, combine, style, label, var_name in config: + for df, combine, style, label, var_name, other_cfg_args in config: if df is None or len(df) == 0: continue merged = _merge_datasets(df, combine, dataset_col, param_name=var_name, err_from_sumw2=err_from_sumw2) actually_plot(merged, x_axis=x_axis, y=y, yerr=yerr, kind=style, label=label, ax=main_ax, dataset_col=dataset_col, dataset_colours=dataset_colours, - colourmap=colourmap, dataset_order=dataset_order) + colourmap=colourmap, dataset_order=dataset_order, other_cfg_args=other_cfg_args) main_ax.set_xlabel(x_axis) if not summary: @@ -406,6 +461,22 @@ def plot_1d_many(df, prefix="", data="data", signal=None, dataset_col="dataset", plot_ratio(summed_data, summed_sims, x=x_axis, y=y, yerr=yerr, ax=summary_ax, error=error, ylim=kwargs["ratio_ylim"], ylabel=kwargs["ratio_ylabel"]) + if other_dset_args: + for df, combine, style, label, var_name, other_dset_args in config: + if (style == "other_dset_types") and (other_dset_args['plot_ratio']): + error = "both" + dset = other_dset_args['dset_type'] + color = dataset_colours[dset] if dset in dataset_colours\ + else other_dset_args['colour'] if other_dset_args['colour']\ + else other_dset_args['tmp_colour'] + add_error = other_dset_args['add_error'] + summed_dset = _merge_datasets( + df, "sum", dataset_col=dataset_col, err_from_sumw2=err_from_sumw2) + if summed_data is not None: + plot_ratio(summed_data, summed_dset, x=x_axis, + y=y, yerr=yerr, ax=summary_ax, error=error, zorder=21, + ylim=kwargs["ratio_ylim"], ylabel=kwargs["ratio_ylabel"], + color=color, add_error=add_error) else: raise RuntimeError(err_msg) return main_ax, summary_ax @@ -441,7 +512,8 @@ def plot_1d(df, kind="line", yscale="lin"): return fig -def plot_ratio(data, sims, x, y, yerr, ax, error="both", ylim=[0., 2], ylabel="Data / MC"): +def plot_ratio(data, sims, x, y, yerr, ax, error="both", ylim=[0., 2], ylabel="Data / MC", + color="k", zorder=22, add_error=True): # make sure both sides agree with the binning merged = data.join(sims, how="left", lsuffix="data", rsuffix="sims") data = merged.filter(like="data", axis="columns").fillna(0) @@ -460,9 +532,10 @@ def plot_ratio(data, sims, x, y, yerr, ax, error="both", ylim=[0., 2], ylabel="D mask = (central != 0) & (lower != 0) ax.errorbar(x=x_axis[mask], y=central[mask], yerr=(lower[mask], upper[mask]), fmt="o", markersize=4, color="k") - draw(ax, "errorbar", x_axis[mask], ys=["y", "yerr"], - y=central[mask], yerr=(lower[mask], upper[mask]), - fmt="o", markersize=4, color="k") + if add_error: + draw(ax, "errorbar", x_axis[mask], ys=["y", "yerr"], + y=central[mask], yerr=(lower[mask], upper[mask]), + fmt="o", markersize=4, color="gray", zorder=zorder-1) elif error == "both": ratio = d / s @@ -471,10 +544,10 @@ def plot_ratio(data, sims, x, y, yerr, ax, error="both", ylim=[0., 2], ylabel="D draw(ax, "errorbar", x_axis, ys=["y", "yerr"], y=ratio, yerr=rel_d_err, - fmt="o", markersize=4, color="k") - draw(ax, "fill_between", x_axis, ys=["y1", "y2"], - y2=1 + rel_s_err, y1=1 - rel_s_err, fill_val=1, - color="gray", alpha=0.7) + fmt="o", markersize=4, color=color, zorder=zorder) + if add_error: + draw(ax, "fill_between", x_axis, ys=["y1", "y2"], color="gray", + y2=1 + rel_s_err, y1=1 - rel_s_err, fill_val=1, alpha=0.7, zorder=zorder-1) ax.set_ylim(ylim) ax.grid(True)