def _user_axis_limit(args, axis):
    """Return the limit configured for *axis* in ``args.limits``, or None.

    A missing ``limits`` attribute, a ``None`` value, or a missing key all
    yield None.
    """
    configured = getattr(args, "limits", None) or {}
    return configured.get(axis)


def _finite_or_none(lims):
    """Return *lims*, or None if it is a list/tuple holding a NaN or inf number.

    Anything that is not a list or tuple (for example None or a string) is
    passed through untouched, as are non-numeric entries inside a sequence.
    NaN never compares equal to itself, so a membership test cannot be used
    to find it; ``np.isfinite`` is applied to each numeric entry instead.
    """
    if isinstance(lims, (list, tuple)):
        for value in lims:
            if isinstance(value, numbers.Number) and not np.isfinite(value):
                return None
    return lims


def _autoscale_ylim(args, df_filtered, weight, ylim_lower, legend_size):
    """Compute ``[lower, upper]`` y-limits leaving headroom for the legend."""
    if weight == "n":
        max_y = df_filtered['sumw'].max()
    else:
        # The masks are only needed on this branch, so build them here.
        data_rows = mask_rows(df_filtered,
                              regex=args.data,
                              level=args.dataset_col)
        mc_rows = mask_rows(df_filtered,
                            regex="^((?!" + args.data + ").)*$",
                            level=args.dataset_col)
        max_mc = df_filtered.loc[mc_rows, 'sumw'].max() * args.lumi
        max_data = df_filtered.loc[data_rows, 'n'].max() if 'n' in df_filtered.columns else 0.1
        # Skip NaN maxima (e.g. no matching rows) so that one empty
        # selection does not hide the peak of the other.
        peaks = [peak for peak in (max_mc, max_data) if not np.isnan(peak)]
        max_y = max(peaks) if peaks else 1
    # Never scale to a peak below 1; this also covers a NaN peak.
    max_y = max_y if max_y >= 1 else 1

    if args.yscale == 'log':
        floor_exp = int(np.floor(np.log10(max_y)))
        # Extra decades above the peak as a buffer for the legend.
        y_buffer = legend_size + 1 if floor_exp > 3 else legend_size
        # A power expression also accepts a fractional legend_size.
        ylim_upper = 10.0 ** (floor_exp + y_buffer)
        ylim_lower = 1e-1  # log axes always start at 0.1, as before
    else:
        buffer_factor = 1 + 0.5 * legend_size
        magnitude_exp = int(np.floor(np.log10(abs(max_y))))
        # Round to the order of magnitude of the peak.
        ylim_upper = round(max_y * buffer_factor, -magnitude_exp)
        if ylim_upper < max_y:
            # Rounding went below the peak (small legend_size): step back up.
            ylim_upper += 10.0 ** magnitude_exp
    return [ylim_lower, ylim_upper]


def _autoscale_xlim(args, xcol):
    """Compute ``[min, max]`` x-limits from the index level *xcol*."""
    if is_intervals(xcol):
        # Use the outermost bin edges, falling back to the opposite edge
        # when the outermost one is not finite.
        right_max = xcol.right.max()
        left_min = xcol.left.min()
        max_x = right_max if np.isfinite(right_max) else xcol.left.max()
        min_x = left_min if np.isfinite(left_min) else xcol.right.min()
        if not np.isfinite(max_x) and getattr(args, "show_over_underflow", False):
            logger.warning("Cannot autoscale overflow bin for x-axis. Removing.")
        return [min_x, max_x]
    dtype = xcol.dtype
    if isinstance(dtype, np.dtype) and np.issubdtype(dtype, np.number):
        # Plain numeric axis: span the values themselves.
        return [xcol.min(), xcol.max()]
    # Non-numeric axis: one unit-wide slot per distinct label.
    return [-0.5, len(xcol.unique()) - 0.5]


def autoscale_values(args, df_filtered, weight, ylim_lower=0.1, legend_size=2):
    """Work out the axis limits to use, autoscaling the axes named in ``args.autoscale``.

    Parameters:
        args: namespace of options. Reads ``autoscale``, ``limits``,
            ``yscale``, ``lumi``, ``data``, ``dataset_col`` and
            ``show_over_underflow`` (only those needed on the path taken).
        df_filtered: binned dataframe; the x-axis is taken from index level 1.
        weight: name of the weighting scheme; ``"n"`` uses the ``sumw``
            column directly for the y peak.
        ylim_lower: lower y-limit for linear axes (log axes use 0.1).
        legend_size: amount of headroom to leave above the peak.

    Returns:
        ``args.limits`` (or ``{}`` if absent) when autoscaling is off or the
        index has more than two levels; otherwise ``{"x": ..., "y": ...}``,
        where an axis that is not autoscaled keeps its configured limit and
        any numeric limit pair containing NaN or inf is replaced by None.
    """
    autoscale = getattr(args, "autoscale", None)
    if not autoscale:
        # Option missing, None or empty: nothing to autoscale.
        return getattr(args, "limits", {})
    if len(df_filtered.index.names) > 2:
        logger.warning("Autoscaling not supported for multi-index dataframes")
        return getattr(args, "limits", {})

    if 'y' in autoscale:
        ylim = _autoscale_ylim(args, df_filtered, weight, ylim_lower, legend_size)
        # Only bins visible above the lower y-limit should drive the x-range.
        # NOTE(review): the threshold is divided by lumi even when
        # weight == "n", where the peak itself is not scaled -- confirm.
        df_above_min = df_filtered.loc[df_filtered['sumw'] > ylim[0] / args.lumi]
    else:
        ylim = _user_axis_limit(args, 'y')
        df_above_min = df_filtered

    if 'x' in autoscale:
        xlim = _autoscale_xlim(args, df_above_min.index.get_level_values(1))
    else:
        xlim = _user_axis_limit(args, 'x')

    return {"x": _finite_or_none(xlim), "y": _finite_or_none(ylim)}
plots, args.outdir, args.extension) return ran_ok @@ -153,6 +221,8 @@ def dress_main_plots(plots, annotations=[], yscale=None, ylabel=None, legend={}, lims = map(float, lims) if axis.lower() in "xy": getattr(main_ax, "set_%slim" % axis)(*lims) + elif lims is None: + continue elif lims.endswith("%"): main_ax.margins(**{axis: float(lims[:-1])}) if xtickrotation: