Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
125 changes: 99 additions & 26 deletions fast_plotter/plotting.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import matplotlib.pyplot as plt
import matplotlib.colors as mc
import logging
import re
logger = logging.getLogger(__name__)


Expand All @@ -26,7 +27,7 @@ def change_brightness(color, amount):
def plot_all(df, project_1d=True, project_2d=True, data="data", signal=None, dataset_col="dataset",
yscale="log", lumi=None, annotations=[], dataset_order=None,
continue_errors=True, bin_variable_replacements={}, colourmap="nipy_spectral",
figsize=None, **kwargs):
figsize=None, other_dset_types={}, **kwargs):
figures = {}

dimensions = utils.binning_vars(df)
Expand All @@ -53,7 +54,7 @@ def plot_all(df, project_1d=True, project_2d=True, data="data", signal=None, dat
plot = plot_1d_many(projected, data=data, signal=signal,
dataset_col=dataset_col, scale_sims=lumi,
colourmap=colourmap, dataset_order=dataset_order,
figsize=figsize, **kwargs
figsize=figsize, other_dset_args=other_dset_types, **kwargs
)
figures[(("project", dim), ("yscale", yscale))] = plot
except Exception as e:
Expand Down Expand Up @@ -107,7 +108,8 @@ def get_colour(self, index=None, name=None):

class FillColl(object):
def __init__(self, n_colors=10, ax=None, fill=True, line=True, dataset_colours=None,
colourmap="nipy_spectral", dataset_order=None, linewidth=0.5, expected_xs=None):
colourmap="nipy_spectral", dataset_order=None, linewidth=0.5,
expected_xs=None, other_dset_args={}):
self.calls = -1
self.expected_xs = expected_xs
self.colors = ColorDict(n_colors=n_colors, order=dataset_order,
Expand All @@ -117,6 +119,8 @@ def __init__(self, n_colors=10, ax=None, fill=True, line=True, dataset_colours=N
self.fill = fill
self.line = line
self.linewidth = linewidth
self.other_dset_args = other_dset_args
self.dataset_colours = dataset_colours

def pre_call(self, column):
ax = self.ax
Expand All @@ -129,16 +133,26 @@ def pre_call(self, column):

def __call__(self, col, **kwargs):
ax, x, y, color = self.pre_call(col)
if self.fill:
if self.fill and not self.other_dset_args:
draw(ax, "fill_between", x=x, ys=["y1"],
y1=y, label=col.name, expected_xs=self.expected_xs,
linewidth=0, color=color, **kwargs)
if self.line:
if self.fill:
label = None
color = "k"
width = self.linewidth
style = "-"
if self.other_dset_args:
style = self.other_dset_args['style']
label = col.name if self.other_dset_args['add_label'] else None
color = self.other_dset_args['colour'] if self.other_dset_args['colour']\
else self.dataset_colours[col.name] if col.name in self.dataset_colours.keys()\
else color
self.color = color
self.other_dset_args['tmp_colour'] = color
width = self.linewidth
else:
style = "-"
label = None
color = "k"
width = self.linewidth
else:
color = None
label = col.name
Expand All @@ -162,7 +176,8 @@ def __call__(self, col, **kwargs):


def actually_plot(df, x_axis, y, yerr, kind, label, ax, dataset_col="dataset",
dataset_colours=None, colourmap="nipy_spectral", dataset_order=None):
dataset_colours=None, colourmap="nipy_spectral",
dataset_order=None, other_cfg_args={}):
expected_xs = df.index.unique(x_axis).values
if kind == "scatter":
draw(ax, "errorbar", x=df.reset_index()[x_axis], ys=["y", "yerr"], y=df[y], yerr=df[yerr],
Expand Down Expand Up @@ -202,6 +217,24 @@ def actually_plot(df, x_axis, y, yerr, kind, label, ax, dataset_col="dataset",
y_up = (summed[y] + summed[yerr]).values
draw(ax, "fill_between", x, ys=["y1", "y2"], y2=y_down, y1=y_up,
color="gray", alpha=0.7, expected_xs=expected_xs)
elif kind == "other_dset_types":
if 'regex' not in other_cfg_args:
raise RuntimeError("Must specify a regex for other plotting datatype to be applied to")
options = ["alpha", "style", "width", "add_label", "add_error", "regex"]
alpha, style, width, add_label, add_error, regex = [other_cfg_args[key] for key in options]
filler = FillColl(n_datasets, ax=ax, fill=True, colourmap=colourmap, dataset_colours=dataset_colours,
dataset_order=dataset_order, expected_xs=expected_xs, linewidth=width,
other_dset_args=other_cfg_args)
vals.apply(filler, axis=0, step="mid")
if add_error:
for dset in list(set(df.reset_index()[dataset_col])):
if not re.compile(regex).match(dset):
continue
color = filler.color
dset_df = df.reset_index().loc[df.reset_index()[dataset_col] == dset].reset_index()
x = dset_df[x_axis]
draw(ax, "fill_between", x, ys=["y1", "y2"], y1=dset_df.eval("sumw+sqrt(sumw2)"),
y2=dset_df.eval("sumw-sqrt(sumw2)"), color=color, alpha=alpha, expected_xs=expected_xs)
else:
raise RuntimeError("Unknown value for 'kind', '{}'".format(kind))

Expand Down Expand Up @@ -330,7 +363,7 @@ def plot_1d_many(df, prefix="", data="data", signal=None, dataset_col="dataset",
kind_data="scatter", kind_sims="fill-error-last", kind_signal="line",
scale_sims=None, summary="ratio-error-both", colourmap="nipy_spectral",
dataset_order=None, figsize=(5, 6), show_over_underflow=False,
dataset_colours=None, err_from_sumw2=False, data_legend="Data", **kwargs):
dataset_colours=None, err_from_sumw2=False, data_legend="Data", other_dset_args={}, **kwargs):
y = "sumw"
yvar = "sumw2"
yerr = "err"
Expand All @@ -352,13 +385,39 @@ def plot_1d_many(df, prefix="", data="data", signal=None, dataset_col="dataset",
else:
in_df_signal = None

config_extend = []
if other_dset_args:
for dset_type in other_dset_args.keys():
dset_type_labels = other_dset_args[dset_type]['regex']
other_defaults = {"style": "-", "alpha": 0.2, "width": 1,
"colour": [], "dset_type": dset_type, "add_label": True,
"add_error": True, "plot_ratio": False}
default_specs = {key: val for key, val
in other_defaults.items()
if key not in other_dset_args[dset_type].keys()}
other_dset_args[dset_type].update(default_specs)
in_df_other, in_df_sims = utils.split_data_sims(
in_df_sims, data_labels=dset_type_labels, dataset_level=dataset_col)
config_extend.append((in_df_other, None, "other_dset_types",
dset_type_labels, "plot_other_dset", other_dset_args[dset_type]))
else:
in_df_other = None

def_cfg_args = {"dset_type": ""}
config = [(in_df_sims, plot_sims, kind_sims, "Monte Carlo", "plot_sims", def_cfg_args),
(in_df_data, plot_data, kind_data, data_legend, "plot_data", def_cfg_args),
(in_df_signal, plot_signal, kind_signal, "Signal", "plot_signal", def_cfg_args),
]

config.extend(config_extend)

if in_df_data is None or in_df_sims is None:
summary = None
if not summary:
fig, main_ax = plt.subplots(1, 1, figsize=figsize)
fig, main_ax = plt.subplots(1, 1, figsize=[float(i) for i in figsize])
else:
fig, ax = plt.subplots(
2, 1, gridspec_kw={"height_ratios": (3, 1)}, sharex=True, figsize=figsize)
2, 1, gridspec_kw={"height_ratios": (3, 1)}, sharex=True, figsize=[float(i) for i in figsize])
fig.subplots_adjust(hspace=.1)
main_ax, summary_ax = ax

Expand All @@ -370,18 +429,14 @@ def plot_1d_many(df, prefix="", data="data", signal=None, dataset_col="dataset",
"Too few dimensions to multiple 1D graphs, use plot_1d instead")
x_axis = x_axis[0]

config = [(in_df_sims, plot_sims, kind_sims, "Monte Carlo", "plot_sims"),
(in_df_data, plot_data, kind_data, data_legend, "plot_data"),
(in_df_signal, plot_signal, kind_signal, "Signal", "plot_signal"),
]
for df, combine, style, label, var_name in config:
for df, combine, style, label, var_name, other_cfg_args in config:
if df is None or len(df) == 0:
continue
merged = _merge_datasets(df, combine, dataset_col, param_name=var_name, err_from_sumw2=err_from_sumw2)
actually_plot(merged, x_axis=x_axis, y=y, yerr=yerr, kind=style,
label=label, ax=main_ax, dataset_col=dataset_col,
dataset_colours=dataset_colours,
colourmap=colourmap, dataset_order=dataset_order)
colourmap=colourmap, dataset_order=dataset_order, other_cfg_args=other_cfg_args)
main_ax.set_xlabel(x_axis)

if not summary:
Expand All @@ -406,6 +461,22 @@ def plot_1d_many(df, prefix="", data="data", signal=None, dataset_col="dataset",
plot_ratio(summed_data, summed_sims, x=x_axis,
y=y, yerr=yerr, ax=summary_ax, error=error,
ylim=kwargs["ratio_ylim"], ylabel=kwargs["ratio_ylabel"])
if other_dset_args:
for df, combine, style, label, var_name, other_dset_args in config:
if (style == "other_dset_types") and (other_dset_args['plot_ratio']):
error = "both"
dset = other_dset_args['dset_type']
color = dataset_colours[dset] if dset in dataset_colours\
else other_dset_args['colour'] if other_dset_args['colour']\
else other_dset_args['tmp_colour']
add_error = other_dset_args['add_error']
summed_dset = _merge_datasets(
df, "sum", dataset_col=dataset_col, err_from_sumw2=err_from_sumw2)
if summed_data is not None:
plot_ratio(summed_data, summed_dset, x=x_axis,
y=y, yerr=yerr, ax=summary_ax, error=error, zorder=21,
ylim=kwargs["ratio_ylim"], ylabel=kwargs["ratio_ylabel"],
color=color, add_error=add_error)
else:
raise RuntimeError(err_msg)
return main_ax, summary_ax
Expand Down Expand Up @@ -441,7 +512,8 @@ def plot_1d(df, kind="line", yscale="lin"):
return fig


def plot_ratio(data, sims, x, y, yerr, ax, error="both", ylim=[0., 2], ylabel="Data / MC"):
def plot_ratio(data, sims, x, y, yerr, ax, error="both", ylim=[0., 2], ylabel="Data / MC",
color="k", zorder=22, add_error=True):
# make sure both sides agree with the binning
merged = data.join(sims, how="left", lsuffix="data", rsuffix="sims")
data = merged.filter(like="data", axis="columns").fillna(0)
Expand All @@ -460,9 +532,10 @@ def plot_ratio(data, sims, x, y, yerr, ax, error="both", ylim=[0., 2], ylabel="D
mask = (central != 0) & (lower != 0)
ax.errorbar(x=x_axis[mask], y=central[mask], yerr=(lower[mask], upper[mask]),
fmt="o", markersize=4, color="k")
draw(ax, "errorbar", x_axis[mask], ys=["y", "yerr"],
y=central[mask], yerr=(lower[mask], upper[mask]),
fmt="o", markersize=4, color="k")
if add_error:
draw(ax, "errorbar", x_axis[mask], ys=["y", "yerr"],
y=central[mask], yerr=(lower[mask], upper[mask]),
fmt="o", markersize=4, color="gray", zorder=zorder-1)

elif error == "both":
ratio = d / s
Expand All @@ -471,10 +544,10 @@ def plot_ratio(data, sims, x, y, yerr, ax, error="both", ylim=[0., 2], ylabel="D

draw(ax, "errorbar", x_axis, ys=["y", "yerr"],
y=ratio, yerr=rel_d_err,
fmt="o", markersize=4, color="k")
draw(ax, "fill_between", x_axis, ys=["y1", "y2"],
y2=1 + rel_s_err, y1=1 - rel_s_err, fill_val=1,
color="gray", alpha=0.7)
fmt="o", markersize=4, color=color, zorder=zorder)
if add_error:
draw(ax, "fill_between", x_axis, ys=["y1", "y2"], color="gray",
y2=1 + rel_s_err, y1=1 - rel_s_err, fill_val=1, alpha=0.7, zorder=zorder-1)

ax.set_ylim(ylim)
ax.grid(True)
Expand Down