Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
74 changes: 72 additions & 2 deletions fast_plotter/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,14 @@
import six
import logging
import matplotlib
import numpy as np
import numbers
matplotlib.use('Agg')
matplotlib.rcParams.update({'figure.autolayout': True})
from .version import __version__ # noqa
from .utils import read_binned_df, weighting_vars # noqa
from .utils import read_binned_df, weighting_vars, binning_vars # noqa
from .utils import decipher_filename, mask_rows # noqa
from .plotting import plot_all, add_annotations # noqa
from .plotting import plot_all, add_annotations, is_intervals # noqa


logger = logging.getLogger("fast_plotter")
Expand Down Expand Up @@ -101,10 +103,75 @@ def recursive_replace(value, replacements):
return args


def autoscale_values(args, df_filtered, weight, ylim_lower=0.1, legend_size=2):
if hasattr(args, "autoscale"):
data_rows = mask_rows(df_filtered,
regex=args.data,
level=args.dataset_col)
mc_rows = mask_rows(df_filtered,
regex="^((?!"+args.data+").)*$",
level=args.dataset_col)
if len(df_filtered.index.names) > 2:
logger.warn("Autoscaling not supported for multi-index dataframes")
limits = args.limits if 'limits' in args else {}
else:
if 'y' in args.autoscale:
if weight == "n":
max_y = df_filtered['sumw'].max()
else:
max_mc = df_filtered.loc[mc_rows, 'sumw'].max()*args.lumi
max_data = df_filtered.loc[data_rows, 'n'].max() if 'n' in df_filtered.columns else 0.1
max_y = max(max_mc, max_data)
max_y = max_y if max_y >= 1 else 1
if args.yscale == 'log':
ylim_upper_floor = int(np.floor(np.log10(max_y)))
y_buffer = (legend_size + 1 if ylim_upper_floor > 3
else legend_size if ylim_upper_floor > 2
else legend_size) # Buffer for legend
ylim_upper = float('1e'+str(ylim_upper_floor+y_buffer))
ylim_lower = 1e-1
else:
buffer_factor = 1 + 0.5*legend_size
ylim_upper = round(max_y*buffer_factor, -int(np.floor(np.log10(abs(max_y))))) # Buffer for legend
ylim = [ylim_lower, ylim_upper]
df_aboveMin = df_filtered.loc[df_filtered['sumw'] > ylim_lower/args.lumi]
else:
if 'limits' in args:
ylim = args.limits['y'] if 'y' in args.limits else None
else:
ylim = None
df_aboveMin = df_filtered.copy()
xcol = df_aboveMin.index.get_level_values(1)
if 'x' in args.autoscale: # Determine x-axis limits
if is_intervals(xcol): # If x-axis is interval, take right and leftmost intervals unless they are inf
max_x = xcol.right.max() if np.isfinite(xcol.right.max()) else xcol.left.max()
min_x = xcol.left.min() if np.isfinite(xcol.left.min()) else xcol.right.min()
if not np.isfinite(max_x) and hasattr(args, "show_over_underflow") and args.show_over_underflow:
logger.warn("Cannot autoscale overflow bin for x-axis. Removing.")
xlim = [min_x, max_x]
elif isinstance(xcol, numbers.Number):
xlim = [xcol.min, xcol.max]
else:
xlim = [-0.5, len(xcol.unique()) - 0.5] # For non-numeric x-axis (e.g. mtn range)
else:
if 'limits' in args:
xlim = args.limits['x'] if 'x' in args.limits else None
else:
xlim = None

xlim = None if xlim is not None and np.NaN in xlim else xlim
ylim = None if ylim is not None and np.NaN in ylim else ylim
limits = {"x": xlim, "y": ylim}
else:
limits = args.limits if 'limits' in args else {}
return limits


def process_one_file(infile, args):
logger.info("Processing: " + infile)
df = read_binned_df(infile, dtype={args.dataset_col: str})
weights = weighting_vars(df)
legend_size = args.legend_size if hasattr(args, "legend_size") else 2
ran_ok = True
for weight in weights:
if args.weights and weight not in args.weights:
Expand Down Expand Up @@ -132,6 +199,7 @@ def process_one_file(infile, args):
df_filtered = df_filtered.groupby(level=df.index.names).sum()
plots, ok = plot_all(df_filtered, **vars(args))
ran_ok &= ok
args.limits = autoscale_values(args, df_filtered, weight, legend_size=legend_size)
dress_main_plots(plots, **vars(args))
save_plots(infile, weight, plots, args.outdir, args.extension)
return ran_ok
Expand All @@ -153,6 +221,8 @@ def dress_main_plots(plots, annotations=[], yscale=None, ylabel=None, legend={},
lims = map(float, lims)
if axis.lower() in "xy":
getattr(main_ax, "set_%slim" % axis)(*lims)
elif lims is None:
continue
elif lims.endswith("%"):
main_ax.margins(**{axis: float(lims[:-1])})
if xtickrotation:
Expand Down