diff --git a/mplexporter/utils.py b/mplexporter/utils.py index 22f488a..f42ad1b 100644 --- a/mplexporter/utils.py +++ b/mplexporter/utils.py @@ -209,6 +209,29 @@ def get_text_style(text): return style +def tick_format_props(formatter, tickvalues, labels): + if isinstance(formatter, ticker.NullFormatter): + return "", "" + if isinstance(formatter, ticker.StrMethodFormatter): + convertor = StrMethodTickFormatterConvertor(formatter) + return convertor.output, "str_method" + if isinstance(formatter, ticker.PercentFormatter): + return { + "xmax": formatter.xmax, + "decimals": formatter.decimals, + "symbol": formatter.symbol, + }, "percent" + if hasattr(ticker, 'IndexFormatter') and isinstance(formatter, ticker.IndexFormatter): + return [text.get_text() for text in labels], "index" + if isinstance(formatter, ticker.FixedFormatter): + return list(formatter.seq), "fixed" + if isinstance(formatter, ticker.FuncFormatter) and tickvalues: + return [formatter(value, i) for i, value in enumerate(tickvalues)], "fixed" + if not any(label.get_visible() for label in labels): + return "", "" + return None, "" + + def get_axis_properties(axis): """Return the property dictionary for a matplotlib.Axis instance""" props = {} @@ -233,39 +256,20 @@ def get_axis_properties(axis): props['nticks'] = len(tick_locs) props['tickvalues'] = tick_locs if isinstance(locator, ticker.FixedLocator) else None + minor_locator = axis.get_minor_locator() + props['minor_tickvalues'] = list(axis.get_minorticklocs()) if minor_locator else None + props['minorticklength'] = axis._minor_tick_kw.get('size', None) + props['majorticklength'] = axis._major_tick_kw.get('size', None) + # Find tick formats - props['tickformat_formatter'] = "" - formatter = axis.get_major_formatter() - if isinstance(formatter, ticker.NullFormatter): - props['tickformat'] = "" - elif isinstance(formatter, ticker.StrMethodFormatter): - convertor = StrMethodTickFormatterConvertor(formatter) - props['tickformat'] = convertor.output - props['tickformat_formatter'] = "str_method" - elif isinstance(formatter, ticker.PercentFormatter): - props['tickformat'] = { - "xmax": formatter.xmax, - "decimals": formatter.decimals, - "symbol": formatter.symbol, - } - props['tickformat_formatter'] = "percent" - elif hasattr(ticker, 'IndexFormatter') and isinstance(formatter, ticker.IndexFormatter): - # IndexFormatter was dropped in matplotlib 3.5 - props['tickformat'] = [text.get_text() for text in axis.get_ticklabels()] - props['tickformat_formatter'] = "index" - elif isinstance(formatter, ticker.FixedFormatter): - props['tickformat'] = list(formatter.seq) - props['tickformat_formatter'] = "fixed" - elif isinstance(formatter, ticker.FuncFormatter): + major_formatter = axis.get_major_formatter() + if isinstance(major_formatter, ticker.FuncFormatter) and props['tickvalues'] is None: # It's impossible for JS to re-run our function, so run it now and save as Fixed. - if props['tickvalues'] is None: - props['tickvalues'] = tick_locs - props['tickformat'] = [formatter(value, i) for i, value in enumerate(props['tickvalues'])] - props['tickformat_formatter'] = "fixed" - elif not any(label.get_visible() for label in axis.get_ticklabels()): - props['tickformat'] = "" - else: - props['tickformat'] = None + props['tickvalues'] = tick_locs + props['minor_tickformat'], props['minor_tickformat_formatter'] = tick_format_props( + axis.get_minor_formatter(), props['minor_tickvalues'], axis.get_minorticklabels()) + props['tickformat'], props['tickformat_formatter'] = tick_format_props( + major_formatter, props['tickvalues'], axis.get_ticklabels()) # Get axis scale props['scale'] = axis.get_scale() @@ -279,6 +283,7 @@ def get_axis_properties(axis): # Get associated grid props['grid'] = get_grid_style(axis) + props['minor_grid'] = get_grid_style(axis, which='minor') # get axis visibility props['visible'] = axis.get_visible() @@ -286,21 +291,24 @@ def get_axis_properties(axis): return props -def get_grid_style(axis): - gridlines = axis.get_gridlines() - if axis._major_tick_kw['gridOn'] and len(gridlines) > 0: - color = export_color(gridlines[0].get_color()) - alpha = gridlines[0].get_alpha() - dasharray = get_dasharray(gridlines[0]) - linewidth = gridlines[0].get_linewidth() - return dict(gridOn=True, - color=color, - dasharray=dasharray, - linewidth=linewidth, - alpha=alpha) - else: +def get_grid_style(axis, which='major'): + tick_kw = axis._minor_tick_kw if which == 'minor' else axis._major_tick_kw + + if not tick_kw.get('gridOn'): return {"gridOn": False} + rc = matplotlib.rcParams + color = export_color(tick_kw.get('grid_color', tick_kw.get('grid_c', rc['grid.color']))) + alpha = tick_kw.get('grid_alpha', rc['grid.alpha']) + dasharray = dasharray_from_linestyle(tick_kw.get('grid_linestyle', tick_kw.get('grid_ls', rc['grid.linestyle']))) + linewidth = tick_kw.get('grid_linewidth', tick_kw.get('grid_lw', rc['grid.linewidth'])) + + return dict(gridOn=True, + color=color, + dasharray=dasharray, + linewidth=linewidth, + alpha=alpha) + def get_figure_properties(fig): return {'figwidth': fig.get_figwidth(),