diff --git a/fast_plotter/__main__.py b/fast_plotter/__main__.py index 54221e5..47346c1 100644 --- a/fast_plotter/__main__.py +++ b/fast_plotter/__main__.py @@ -5,12 +5,14 @@ import six import logging import matplotlib +import numpy as np +import numbers matplotlib.use('Agg') matplotlib.rcParams.update({'figure.autolayout': True}) from .version import __version__ # noqa -from .utils import read_binned_df, weighting_vars # noqa +from .utils import read_binned_df, weighting_vars, binning_vars # noqa from .utils import decipher_filename, mask_rows # noqa -from .plotting import plot_all, add_annotations # noqa +from .plotting import plot_all, add_annotations, is_intervals, annotate_xlabel_vals # noqa logger = logging.getLogger("fast_plotter") @@ -44,6 +46,7 @@ def arg_parser(args=None): help="Scale the MC yields by this lumi") parser.add_argument("-y", "--yscale", default="log", choices=["log", "linear"], help="Use this scale for the y-axis") + parser.add_argument("-a", "--annotate_xlabel", action="store_true", help="Split x-axis information onto plot") parser.add_argument('--version', action='version', version='%(prog)s ' + __version__) def split_equals(arg): @@ -94,17 +97,84 @@ def recursive_replace(value, replacements): if isinstance(value, six.string_types): return Template(value).safe_substitute(replacements) return value - replacements = dict(args.variables) args = Namespace(**recursive_replace(vars(args), replacements)) return args +def autoscale_values(args, df_filtered, weight, ylim_lower=0.5, legend_size=2): + if hasattr(args, "autoscale"): + legend_size = int(legend_size) + data_rows = mask_rows(df_filtered, + regex=args.data, + level=args.dataset_col) + mc_rows = mask_rows(df_filtered, + regex="^((?!"+args.data+").)*$", + level=args.dataset_col) + if len(df_filtered.index.names) > 2: + logger.warn("Autoscaling not supported for multi-index dataframes") + limits = args.limits if 'limits' in args else {} + else: + if 'y' in args.autoscale: + if weight == "n": + max_y = df_filtered['sumw'].max() + else: + max_mc = df_filtered.loc[mc_rows, 'sumw'].max()*args.lumi + max_data = df_filtered.loc[data_rows, 'n'].max() if 'n' in df_filtered.columns else 0.1 + max_y = max(max_mc, max_data) + max_y = max_y if max_y >= 1 else 1 + if args.yscale == 'log': + ylim_upper_floor = int(np.floor(np.log10(max_y))) + y_buffer = (legend_size + 1 if ylim_upper_floor > 3 + else legend_size if ylim_upper_floor > 2 + else legend_size) # Buffer for legend + ylim_upper = float('1e'+str(ylim_upper_floor+y_buffer)) + else: + buffer_factor = 1 + 0.5*legend_size + ylim_upper = round(max_y*buffer_factor, -int(np.floor(np.log10(abs(max_y))))) # Buffer for legend + ylim = [ylim_lower, ylim_upper] + df_aboveMin = df_filtered.loc[df_filtered['sumw'] > ylim_lower/args.lumi] + else: + if 'limits' in args: + ylim = args.limits['y'] if 'y' in args.limits else None + else: + ylim = None + if 'x' in args.autoscale: + df_aboveMin = df_filtered.loc[df_filtered['sumw'] > ylim_lower/args.lumi] + else: + df_aboveMin = df_filtered.copy() + xcol = df_aboveMin.index.get_level_values(1) + if 'x' in args.autoscale: # Determine x-axis limits + if is_intervals(xcol): # If x-axis is interval, take right and leftmost intervals unless they are inf + max_x = xcol.right.max() if np.isfinite(xcol.right.max()) else xcol.left.max() + min_x = xcol.left.min() if np.isfinite(xcol.left.min()) else xcol.right.min() + if not np.isfinite(max_x) and hasattr(args, "show_over_underflow") and args.show_over_underflow: + logger.warn("Cannot autoscale overflow bin for x-axis. Removing.") + xlim = [min_x, max_x] + elif isinstance(xcol, numbers.Number): + xlim = [xcol.min, xcol.max] + else: + xlim = [-0.5, len(xcol.unique()) - 0.5] # For non-numeric x-axis (e.g. mtn range) + else: + if 'limits' in args: + xlim = args.limits['x'] if 'x' in args.limits else None + else: + xlim = None + + xlim = None if xlim is not None and np.NaN in xlim else xlim + ylim = None if ylim is not None and np.NaN in ylim else ylim + limits = {"x": xlim, "y": ylim} + else: + limits = args.limits if 'limits' in args else {} + return limits + + def process_one_file(infile, args): logger.info("Processing: " + infile) df = read_binned_df(infile, dtype={args.dataset_col: str}) weights = weighting_vars(df) + legend_size = args.legend_size if hasattr(args, "legend_size") else 2 ran_ok = True for weight in weights: if args.weights and weight not in args.weights: @@ -132,29 +202,39 @@ def process_one_file(infile, args): df_filtered = df_filtered.groupby(level=df.index.names).sum() plots, ok = plot_all(df_filtered, **vars(args)) ran_ok &= ok - dress_main_plots(plots, **vars(args)) + args.limits = autoscale_values(args, df_filtered, weight, legend_size=legend_size) + dress_main_plots(plots, **vars(args), df=df_filtered) save_plots(infile, weight, plots, args.outdir, args.extension) return ran_ok def dress_main_plots(plots, annotations=[], yscale=None, ylabel=None, legend={}, - limits={}, xtickrotation=None, **kwargs): + limits={}, xtickrotation=None, df=None, annotate_xlabel=False, grid='both', **kwargs): for main_ax, summary_ax in plots.values(): - add_annotations(annotations, main_ax) + add_annotations(annotations, main_ax, summary_ax) + if annotate_xlabel: + met_cats=annotate_xlabel_vals(df, main_ax) if yscale: main_ax.set_yscale(yscale) if ylabel: main_ax.set_ylabel(ylabel) - main_ax.legend(**legend) - main_ax.grid(True) + legend['ncol'] = int(legend['ncol']) + main_ax.legend(**legend).set_zorder(20) + main_ax.grid(axis=grid) main_ax.set_axisbelow(True) for axis, lims in limits.items(): if isinstance(lims, (tuple, list)): lims = map(float, lims) if axis.lower() in "xy": getattr(main_ax, "set_%slim" % axis)(*lims) + elif lims is None: + continue elif lims.endswith("%"): main_ax.margins(**{axis: float(lims[:-1])}) + if annotate_xlabel: + x_ticks = [i for i in range(len(met_cats))] + main_ax.set_xticks(x_ticks) + main_ax.set_xticklabels(met_cats) if xtickrotation: matplotlib.pyplot.xticks(rotation=xtickrotation) diff --git a/fast_plotter/plotting.py b/fast_plotter/plotting.py index 5594caa..43a8697 100644 --- a/fast_plotter/plotting.py +++ b/fast_plotter/plotting.py @@ -6,9 +6,9 @@ import matplotlib.pyplot as plt import matplotlib.colors as mc import logging +import re logger = logging.getLogger(__name__) - def change_brightness(color, amount): if amount is None: return @@ -22,20 +22,68 @@ def change_brightness(color, amount): c = colorsys.rgb_to_hls(*color) return colorsys.hls_to_rgb(c[0], 1 - amount * (1 - c[1]), c[2]) - +def annotate_xlabel_vals(df, ax, binning_col='category', regex="(?P.*?(?=\s))\s(?P\d.*?(?=\d))(?P.*?(?=,\s)),\s(?P.*)", backup_regex="(?P.*?(?=\,))(?P()),\s(?P.*)"): + df=df.reset_index() + re_compiler = lambda category,regex: re.compile(regex).match(str(category.replace("$","").replace("\infty","$\infty$"))) + compile_correct_regex = lambda category: (re_compiler(category,regex) if re_compiler(category,regex) is not None else re_compiler(category,backup_regex)).groups() + met_cats=[compile_correct_regex(category)[3:][-1] for category in df[binning_col].unique()] + cats=[compile_correct_regex(category)[:3] for category in df[binning_col].unique()] + n_cats = len(cats) + for i, cat in enumerate(cats): + if i==0: + a1,a2,a3=cat + old_cat = cat + labels = {i:{0:{val.strip():0}} for i,val in enumerate(cat)} + else: + for j, val in enumerate(cat): + val = val.strip() + if old_cat[j].strip() == val: + if j == len(cat)-1: + old_cat=cat + continue + else: + labels[j][i]={val:0} + if j == len(cat)-1: + old_cat=cat + for depth, label in labels.items(): + for i, split in enumerate(label): + label_str = list(label[split].keys())[0] + if i == len(label) - 1: + label_length = len(cats) - split + else: + label_length = dict(enumerate(label))[i+1] - split + labels[depth][split][label_str]=label_length + label_positions = {} + for depth, label in labels.items(): + label_positions[depth] = {} + for left_edge, len_dict in label.items(): + label_str = list(len_dict.keys())[0] + position = left_edge + (len_dict[label_str]/2) + if label_str in label_positions[depth]: + label_positions[depth][label_str].append(position-0.5) + else: + label_positions[depth][label_str] = [position-0.5] + + for depth, label_dict in label_positions.items(): + y = (0.80 - 0.05*(depth + 1)) + for label, xvals in label_dict.items(): + for x in xvals: + x = (x+0.5)/n_cats + ax.text(x, y, label, fontsize=12-depth, transform=ax.transAxes, ha='center', weight='medium') + return met_cats + def plot_all(df, project_1d=True, project_2d=True, data="data", signal=None, dataset_col="dataset", yscale="log", lumi=None, annotations=[], dataset_order=None, continue_errors=True, bin_variable_replacements={}, colourmap="nipy_spectral", - figsize=None, **kwargs): + figsize=None, other_dset_types={}, grid='both', **kwargs): figures = {} - dimensions = utils.binning_vars(df) ran_ok = True if len(dimensions) == 1: df = utils.rename_index(df, bin_variable_replacements) figures[(("yscale", yscale),)] = plot_1d( - df, yscale=yscale, annotations=annotations) + df, yscale=yscale, annotations=annotations, grid=grid) if dataset_col in dimensions: dimensions = tuple(dim for dim in dimensions if dim != dataset_col) @@ -53,7 +101,7 @@ def plot_all(df, project_1d=True, project_2d=True, data="data", signal=None, dat plot = plot_1d_many(projected, data=data, signal=signal, dataset_col=dataset_col, scale_sims=lumi, colourmap=colourmap, dataset_order=dataset_order, - figsize=figsize, **kwargs + figsize=figsize, other_dset_args=other_dset_types, grid=grid, **kwargs ) figures[(("project", dim), ("yscale", yscale))] = plot except Exception as e: @@ -107,7 +155,8 @@ def get_colour(self, index=None, name=None): class FillColl(object): def __init__(self, n_colors=10, ax=None, fill=True, line=True, dataset_colours=None, - colourmap="nipy_spectral", dataset_order=None, linewidth=0.5, expected_xs=None): + colourmap="nipy_spectral", dataset_order=None, linewidth=0.5, + expected_xs=None, other_dset_args={}): self.calls = -1 self.expected_xs = expected_xs self.colors = ColorDict(n_colors=n_colors, order=dataset_order, @@ -117,6 +166,8 @@ def __init__(self, n_colors=10, ax=None, fill=True, line=True, dataset_colours=N self.fill = fill self.line = line self.linewidth = linewidth + self.other_dset_args = other_dset_args + self.dataset_colours = dataset_colours def pre_call(self, column): ax = self.ax @@ -129,18 +180,28 @@ def pre_call(self, column): def __call__(self, col, **kwargs): ax, x, y, color = self.pre_call(col) - if self.fill: + if self.fill and not self.other_dset_args: draw(ax, "fill_between", x=x, ys=["y1"], y1=y, label=col.name, expected_xs=self.expected_xs, linewidth=0, color=color, **kwargs) if self.line: if self.fill: - label = None - color = "k" - width = self.linewidth - style = "-" + if self.other_dset_args: + style = self.other_dset_args['style'] + label = col.name if self.other_dset_args['add_label'] else None + color = self.other_dset_args['colour'] if self.other_dset_args['colour']\ + else self.dataset_colours[col.name] if col.name in self.dataset_colours.keys()\ + else color + self.color = color + self.other_dset_args['tmp_colour'] = color + width = self.linewidth + else: + style = "-" + label = None + color = "k" + width = self.linewidth else: - color = None + color = color label = col.name width = 2 style = "--" @@ -162,7 +223,8 @@ def __call__(self, col, **kwargs): def actually_plot(df, x_axis, y, yerr, kind, label, ax, dataset_col="dataset", - dataset_colours=None, colourmap="nipy_spectral", dataset_order=None): + dataset_colours=None, colourmap="nipy_spectral", + dataset_order=None, other_cfg_args={}, grid='both'): expected_xs = df.index.unique(x_axis).values if kind == "scatter": draw(ax, "errorbar", x=df.reset_index()[x_axis], ys=["y", "yerr"], y=df[y], yerr=df[yerr], @@ -202,6 +264,24 @@ def actually_plot(df, x_axis, y, yerr, kind, label, ax, dataset_col="dataset", y_up = (summed[y] + summed[yerr]).values draw(ax, "fill_between", x, ys=["y1", "y2"], y2=y_down, y1=y_up, color="gray", alpha=0.7, expected_xs=expected_xs) + elif kind == "other_dset_types": + if 'regex' not in other_cfg_args: + raise RuntimeError("Must specify a regex for other plotting datatype to be applied to") + options = ["alpha", "style", "width", "add_label", "add_error", "regex"] + alpha, style, width, add_label, add_error, regex = [other_cfg_args[key] for key in options] + filler = FillColl(n_datasets, ax=ax, fill=True, colourmap=colourmap, dataset_colours=dataset_colours, + dataset_order=dataset_order, expected_xs=expected_xs, linewidth=width, + other_dset_args=other_cfg_args) + vals.apply(filler, axis=0, step="mid") + if add_error: + for dset in list(set(df.reset_index()[dataset_col])): + if not re.compile(regex).match(dset): + continue + color = filler.color + dset_df = df.reset_index().loc[df.reset_index()[dataset_col] == dset].reset_index() + x = dset_df[x_axis] + draw(ax, "fill_between", x, ys=["y1", "y2"], y1=dset_df.eval("sumw+sqrt(sumw2)"), + y2=dset_df.eval("sumw-sqrt(sumw2)"), color=color, alpha=alpha, expected_xs=expected_xs) else: raise RuntimeError("Unknown value for 'kind', '{}'".format(kind)) @@ -330,7 +410,7 @@ def plot_1d_many(df, prefix="", data="data", signal=None, dataset_col="dataset", kind_data="scatter", kind_sims="fill-error-last", kind_signal="line", scale_sims=None, summary="ratio-error-both", colourmap="nipy_spectral", dataset_order=None, figsize=(5, 6), show_over_underflow=False, - dataset_colours=None, err_from_sumw2=False, data_legend="Data", **kwargs): + dataset_colours=None, err_from_sumw2=False, data_legend="Data", other_dset_args={}, grid='both', **kwargs): y = "sumw" yvar = "sumw2" yerr = "err" @@ -352,6 +432,33 @@ def plot_1d_many(df, prefix="", data="data", signal=None, dataset_col="dataset", else: in_df_signal = None + config_extend = [] + if other_dset_args: + for dset_type in other_dset_args.keys(): + dset_type_labels = other_dset_args[dset_type]['regex'] + other_defaults = {"style": "-", "alpha": 0.2, "width": 1, + "colour": [], "dset_type": dset_type, "add_label": True, + "add_error": True, "plot_ratio": False} + default_specs = {key: val for key, val + in other_defaults.items() + if key not in other_dset_args[dset_type].keys()} + other_dset_args[dset_type].update(default_specs) + in_df_other, in_df_sims = utils.split_data_sims( + in_df_sims, data_labels=dset_type_labels, dataset_level=dataset_col) + config_extend.append((in_df_other, None, "other_dset_types", + dset_type_labels, "plot_other_dset", other_dset_args[dset_type])) + else: + in_df_other = None + + def_cfg_args = {"dset_type": ""} + config = [(in_df_sims, plot_sims, kind_sims, "Monte Carlo", "plot_sims", def_cfg_args), + (in_df_data, plot_data, kind_data, data_legend, "plot_data", def_cfg_args), + (in_df_signal, plot_signal, kind_signal, "Signal", "plot_signal", def_cfg_args), + ] + + config.extend(config_extend) + + figsize = [float(i) for i in figsize] if figsize else None if in_df_data is None or in_df_sims is None: summary = None if not summary: @@ -370,18 +477,16 @@ def plot_1d_many(df, prefix="", data="data", signal=None, dataset_col="dataset", "Too few dimensions to multiple 1D graphs, use plot_1d instead") x_axis = x_axis[0] - config = [(in_df_sims, plot_sims, kind_sims, "Monte Carlo", "plot_sims"), - (in_df_data, plot_data, kind_data, data_legend, "plot_data"), - (in_df_signal, plot_signal, kind_signal, "Signal", "plot_signal"), - ] - for df, combine, style, label, var_name in config: + kwargs.setdefault("is_null_poissonian", False) + for df, combine, style, label, var_name, other_cfg_args in config: if df is None or len(df) == 0: continue - merged = _merge_datasets(df, combine, dataset_col, param_name=var_name, err_from_sumw2=err_from_sumw2) + merged = _merge_datasets(df, combine, dataset_col, param_name=var_name, err_from_sumw2=err_from_sumw2, + is_null_poissonian=kwargs['is_null_poissonian']) actually_plot(merged, x_axis=x_axis, y=y, yerr=yerr, kind=style, label=label, ax=main_ax, dataset_col=dataset_col, dataset_colours=dataset_colours, - colourmap=colourmap, dataset_order=dataset_order) + colourmap=colourmap, dataset_order=dataset_order, other_cfg_args=other_cfg_args, grid=grid) main_ax.set_xlabel(x_axis) if not summary: @@ -392,9 +497,11 @@ def plot_1d_many(df, prefix="", data="data", signal=None, dataset_col="dataset", if summary.startswith("ratio"): main_ax.set_xlabel("") summed_data = _merge_datasets( - in_df_data, "sum", dataset_col=dataset_col, err_from_sumw2=err_from_sumw2) + in_df_data, "sum", dataset_col=dataset_col, err_from_sumw2=err_from_sumw2, + is_null_poissonian=kwargs['is_null_poissonian']) summed_sims = _merge_datasets( - in_df_sims, "sum", dataset_col=dataset_col, err_from_sumw2=err_from_sumw2) + in_df_sims, "sum", dataset_col=dataset_col, err_from_sumw2=err_from_sumw2, + is_null_poissonian=kwargs['is_null_poissonian']) if summary == "ratio-error-both": error = "both" elif summary == "ratio-error-markers": @@ -404,14 +511,31 @@ def plot_1d_many(df, prefix="", data="data", signal=None, dataset_col="dataset", kwargs.setdefault("ratio_ylim", [0., 2.]) kwargs.setdefault("ratio_ylabel", "Data / MC") plot_ratio(summed_data, summed_sims, x=x_axis, - y=y, yerr=yerr, ax=summary_ax, error=error, + y=y, yerr=yerr, ax=summary_ax, error=error, grid=grid, ylim=kwargs["ratio_ylim"], ylabel=kwargs["ratio_ylabel"]) + if other_dset_args: + for df, combine, style, label, var_name, other_dset_args in config: + if (style == "other_dset_types") and (other_dset_args['plot_ratio']): + error = "both" + dset = other_dset_args['dset_type'] + color = dataset_colours[dset] if dset in dataset_colours\ + else other_dset_args['colour'] if other_dset_args['colour']\ + else other_dset_args['tmp_colour'] + add_error = other_dset_args['add_error'] + summed_dset = _merge_datasets( + df, "sum", dataset_col=dataset_col, err_from_sumw2=err_from_sumw2) + if summed_data is not None: + plot_ratio(summed_data, summed_dset, x=x_axis, + y=y, yerr=yerr, ax=summary_ax, error=error, zorder=21, + ylim=kwargs["ratio_ylim"], ylabel=kwargs["ratio_ylabel"], grid=grid, + color=color, add_error=add_error) else: raise RuntimeError(err_msg) return main_ax, summary_ax -def _merge_datasets(df, style, dataset_col, param_name="_merge_datasets", err_from_sumw2=False): +def _merge_datasets(df, style, dataset_col, param_name="_merge_datasets", err_from_sumw2=False, + is_null_poissonian=False): if style == "stack": df = utils.stack_datasets(df, dataset_level=dataset_col) elif style == "sum": @@ -419,29 +543,62 @@ def _merge_datasets(df, style, dataset_col, param_name="_merge_datasets", err_fr elif style: msg = "'{}' must be either 'sum', 'stack' or None. Got {}" raise RuntimeError(msg.format(param_name, style)) - utils.calculate_error(df, do_rel_err=not err_from_sumw2) + utils.calculate_error(df, do_rel_err=not err_from_sumw2, is_null_poissonian=is_null_poissonian) return df -def add_annotations(annotations, ax): +def annotate_lines(cfg, main_ax, summary_ax): + linename = list(cfg.keys())[0] + annotDict = cfg[linename] + if 'values' not in annotDict.keys(): + raise(RuntimeError("Must provide values for line placement.")) + annotDefaults = {"style": "-", "alpha": 1, "width": 1.5, + "colour": 'k', "label": None, "vmin": 0, + "vmax": 1, "zorder": 10, "axes": ["main"]} + annotDict.update({key: value for key, value in annotDefaults.items() + if key not in annotDict.keys()}) + lineKeys = ['values', 'style', 'alpha', 'width', 'colour', 'label', 'vmin', 'vmax', 'zorder', 'axes'] + if set(annotDict.keys()).difference(set(lineKeys)): + logger.warn("Invalid parameter(s) given to line annotations. Options are {}".format(lineKeys)) + values, style, alpha, width, colour, label, vmin, vmax, zorder, axes = [annotDict[key] for key in lineKeys] + for axis in axes: + awidth = 0.6 * width if (axis == 'summary') else width + ax = main_ax if (str(axis) == 'main') else summary_ax if (str(axis) == 'summary') else None + if ax is None: + logger.warn("Axis must exist and either be 'main' or 'summary'. {} is None".format(axis)) + continue + for value in values: + value = float(value) + if 'hline' in linename: + ax.axhline(value, vmin, vmax, color=colour, label=label, + alpha=alpha, ls=style, lw=awidth, zorder=zorder) + if 'vline' in linename: + ax.axvline(value, vmin, vmax, color=colour, label=label, + alpha=alpha, ls=style, lw=awidth, zorder=zorder) + + +def add_annotations(annotations, ax, summary_ax=None): for cfg in annotations: + if list(filter(lambda key: re.match("(.*hline.*|.*vline.*)", key), cfg.keys())): + annotate_lines(cfg, ax, summary_ax) + continue cfg = cfg.copy() s = cfg.pop("text") xy = cfg.pop("position") cfg.setdefault("xycoords", "axes fraction") ax.annotate(s, xy=xy, **cfg) - -def plot_1d(df, kind="line", yscale="lin"): +def plot_1d(df, kind="line", yscale="lin", grid='both'): fig, ax = plt.subplots(1) df["sumw"].plot(kind=kind) ax.set_axisbelow(True) - plt.grid(True) + plt.grid(axis=grid) plt.yscale(yscale) return fig -def plot_ratio(data, sims, x, y, yerr, ax, error="both", ylim=[0., 2], ylabel="Data / MC"): +def plot_ratio(data, sims, x, y, yerr, ax, error="both", ylim=[0., 2], ylabel="Data / MC", + color="k", zorder=22, add_error=True, grid='both'): # make sure both sides agree with the binning merged = data.join(sims, how="left", lsuffix="data", rsuffix="sims") data = merged.filter(like="data", axis="columns").fillna(0) @@ -460,9 +617,10 @@ def plot_ratio(data, sims, x, y, yerr, ax, error="both", ylim=[0., 2], ylabel="D mask = (central != 0) & (lower != 0) ax.errorbar(x=x_axis[mask], y=central[mask], yerr=(lower[mask], upper[mask]), fmt="o", markersize=4, color="k") - draw(ax, "errorbar", x_axis[mask], ys=["y", "yerr"], - y=central[mask], yerr=(lower[mask], upper[mask]), - fmt="o", markersize=4, color="k") + if add_error: + draw(ax, "errorbar", x_axis[mask], ys=["y", "yerr"], + y=central[mask], yerr=(lower[mask], upper[mask]), + fmt="o", markersize=4, color="gray", zorder=zorder-1) elif error == "both": ratio = d / s @@ -471,13 +629,13 @@ def plot_ratio(data, sims, x, y, yerr, ax, error="both", ylim=[0., 2], ylabel="D draw(ax, "errorbar", x_axis, ys=["y", "yerr"], y=ratio, yerr=rel_d_err, - fmt="o", markersize=4, color="k") - draw(ax, "fill_between", x_axis, ys=["y1", "y2"], - y2=1 + rel_s_err, y1=1 - rel_s_err, fill_val=1, - color="gray", alpha=0.7) + fmt="o", markersize=4, color=color, zorder=zorder) + if add_error: + draw(ax, "fill_between", x_axis, ys=["y1", "y2"], color="gray", + y2=1 + rel_s_err, y1=1 - rel_s_err, fill_val=1, alpha=0.7, zorder=zorder-1) ax.set_ylim(ylim) - ax.grid(True) + ax.grid(axis=grid) ax.set_axisbelow(True) ax.set_xlabel(x) ax.set_ylabel(ylabel) diff --git a/fast_plotter/utils.py b/fast_plotter/utils.py index b765452..1cfbf57 100644 --- a/fast_plotter/utils.py +++ b/fast_plotter/utils.py @@ -91,7 +91,7 @@ def split_data_sims(df, data_labels=["data"], dataset_level="dataset"): return split_df(df, first_values=data_labels, level=dataset_level) -def calculate_error(df, sumw2_label="sumw2", err_label="err", inplace=True, do_rel_err=True): +def calculate_error(df, sumw2_label="sumw2", err_label="err", inplace=True, do_rel_err=True, is_null_poissonian=False): if not inplace: df = df.copy() if do_rel_err: @@ -105,6 +105,10 @@ def calculate_error(df, sumw2_label="sumw2", err_label="err", inplace=True, do_r elif not do_rel_err and sumw2_label in column: err_name = column.replace(sumw2_label, err_label) df[err_name] = np.sqrt(df[column]) + else: + continue + if is_null_poissonian: + df[err_name] = df[err_name].apply(lambda x: x if x > 1.15 else np.sqrt(1.15**2+x**2)) if not inplace: return df