|
| 1 | +import matplotlib.dates as mdates |
| 2 | +import matplotlib.ticker as ticker |
| 3 | +import numpy as np |
| 4 | +from ._data import _locate_channel_data, _locate_channel_dtype, _locate_channel_scale, _locate_channel_axis, _convert_to_mpl_date |
| 5 | + |
| 6 | + |
| 7 | +def _set_limits(channel, scale): |
| 8 | + """Set the axis limits on the Matplotlib axis |
| 9 | +
|
| 10 | + Parameters |
| 11 | + ---------- |
| 12 | + channel : dict |
| 13 | + The mapping of the channel data and metadata |
| 14 | + scale : dict |
| 15 | + The mapping of the scale metadata and the scale data |
| 16 | + """ |
| 17 | + |
| 18 | + _axis_kwargs = { |
| 19 | + 'x': {'min': 'left', 'max': 'right'}, |
| 20 | + 'y': {'min': 'bottom', 'max': 'top'}, |
| 21 | + } |
| 22 | + |
| 23 | + lims = {} |
| 24 | + |
| 25 | + if channel['dtype'] == 'quantitative': |
| 26 | + # determine limits |
| 27 | + if 'domain' in scale: # domain takes precedence over zero in Altair |
| 28 | + if scale['domain'] == 'unaggregated': |
| 29 | + raise NotImplementedError |
| 30 | + else: |
| 31 | + lims[_axis_kwargs[channel['axis']].get('min')] = scale['domain'][0] |
| 32 | + lims[_axis_kwargs[channel['axis']].get('max')] = scale['domain'][1] |
| 33 | + elif 'type' in scale and scale['type'] != 'linear': |
| 34 | + lims = _set_scale_type(channel, scale) |
| 35 | + else: |
| 36 | + # Check that a positive minimum is zero if zero is True: |
| 37 | + if ('zero' not in scale or scale['zero'] == True) and min(channel['data']) > 0: |
| 38 | + lims[_axis_kwargs[channel['axis']].get('min')] = 0 # quantitative sets min to be 0 by default |
| 39 | + |
| 40 | + # Check that a negative maximum is zero if zero is True: |
| 41 | + if ('zero' not in scale or scale['zero'] == True) and max(channel['data']) < 0: |
| 42 | + lims[_axis_kwargs[channel['axis']].get('max')] = 0 |
| 43 | + |
| 44 | + elif channel['dtype'] == 'temporal': |
| 45 | + # determine limits |
| 46 | + if 'domain' in scale: |
| 47 | + domain = _convert_to_mpl_date(scale['domain']) |
| 48 | + lims[_axis_kwargs[channel['axis']].get('min')] = domain[0] |
| 49 | + lims[_axis_kwargs[channel['axis']].get('max')] = domain[1] |
| 50 | + elif 'type' in scale and scale['type'] != 'time': |
| 51 | + lims = _set_scale_type(channel, scale) |
| 52 | + |
| 53 | + else: |
| 54 | + raise NotImplementedError # Ordinal and Nominal go here? |
| 55 | + |
| 56 | + # set the limits |
| 57 | + if channel['axis'] == 'x': |
| 58 | + channel['ax'].set_xlim(**lims) |
| 59 | + else: |
| 60 | + channel['ax'].set_ylim(**lims) |
| 61 | + |
| 62 | + |
| 63 | +def _set_scale_type(channel, scale): |
| 64 | + """If the scale is non-linear, change the scale and return appropriate axis limits. |
| 65 | + The 'linear' and 'time' scale types are not included here because quantitative defaults to 'linear' |
| 66 | + and temporal defaults to 'time'. The 'utc' and 'sequential' scales are currently not supported. |
| 67 | +
|
| 68 | + Parameters |
| 69 | + ---------- |
| 70 | + channel : dict |
| 71 | + The mapping of the channel data and metadata |
| 72 | + scale : dict |
| 73 | + The mapping of the scale metadata and the scale data |
| 74 | +
|
| 75 | + Returns |
| 76 | + ------- |
| 77 | + lims : dict |
| 78 | + The axis limit mapped to the appropriate axis parameter for scales that change axis limit behavior |
| 79 | + """ |
| 80 | + lims = {} |
| 81 | + if scale['type'] == 'log': |
| 82 | + |
| 83 | + base = 10 # default base is 10 in altair |
| 84 | + if 'base' in scale: |
| 85 | + base = scale['base'] |
| 86 | + |
| 87 | + if channel['axis'] == 'x': |
| 88 | + channel['ax'].set_xscale('log', basex=base) |
| 89 | + # lower limit: round down to nearest major tick (using log base change rule) |
| 90 | + lims['left'] = base**np.floor(np.log10(channel['data'].min())/np.log10(base)) |
| 91 | + else: # y-axis |
| 92 | + channel['ax'].set_yscale('log', basey=base) |
| 93 | + # lower limit: round down to nearest major tick (using log base change rule) |
| 94 | + lims['bottom'] = base**np.floor(np.log10(channel['data'].min())/np.log10(base)) |
| 95 | + |
| 96 | + elif scale['type'] == 'pow' or scale['type'] == 'sqrt': |
| 97 | + """The 'sqrt' scale is just the 'pow' scale with exponent = 0.5. |
| 98 | + When Matplotlib gets a power scale, the following should work: |
| 99 | + |
| 100 | + exponent = 2 # default exponent value for 'pow' scale |
| 101 | + if scale['type'] == 'sqrt': |
| 102 | + exponent = 0.5 |
| 103 | + elif 'exponent' in scale: |
| 104 | + exponent = scale['exponent'] |
| 105 | +
|
| 106 | + if channel['axis'] == 'x': |
| 107 | + channel['ax'].set_xscale('power_scale', exponent=exponent) |
| 108 | + else: # y-axis |
| 109 | + channel['ax'].set_yscale('power_scale', exponent=exponent) |
| 110 | + """ |
| 111 | + raise NotImplementedError |
| 112 | + |
| 113 | + elif scale['type'] == 'utc': |
| 114 | + raise NotImplementedError |
| 115 | + elif scale['type'] == 'sequential': |
| 116 | + raise NotImplementedError("sequential scales used primarily for continuous colors") |
| 117 | + else: |
| 118 | + raise NotImplementedError |
| 119 | + return lims |
| 120 | + |
| 121 | + |
| 122 | +def _set_tick_locator(channel, axis): |
| 123 | + """Set the tick locator if it needs to vary from the default locator |
| 124 | +
|
| 125 | + Parameters |
| 126 | + ---------- |
| 127 | + channel : dict |
| 128 | + The mapping of the channel data and metadata |
| 129 | + axis : dict |
| 130 | + The mapping of the axis metadata and the scale data |
| 131 | + """ |
| 132 | + current_axis = {'x': channel['ax'].xaxis, 'y': channel['ax'].yaxis} |
| 133 | + if 'values' in axis: |
| 134 | + if channel['dtype'] == 'temporal': |
| 135 | + current_axis[channel['axis']].set_major_locator(ticker.FixedLocator(_convert_to_mpl_date(axis.get('values')))) |
| 136 | + elif channel['dtype'] == 'quantitative': |
| 137 | + current_axis[channel['axis']].set_major_locator(ticker.FixedLocator(axis.get('values'))) |
| 138 | + else: |
| 139 | + raise NotImplementedError |
| 140 | + elif 'tickCount' in axis: |
| 141 | + current_axis[channel['axis']].set_major_locator( |
| 142 | + ticker.MaxNLocator(steps=[2, 5, 10], nbins=axis.get('tickCount')+1, min_n_ticks=axis.get('tickCount')) |
| 143 | + ) |
| 144 | + |
| 145 | + |
| 146 | +def _set_tick_formatter(channel, axis): |
| 147 | + """Set the tick formatter. |
| 148 | +
|
| 149 | +
|
| 150 | + Parameters |
| 151 | + ---------- |
| 152 | + channel : dict |
| 153 | + The mapping of the channel data and metadata |
| 154 | + axis : dict |
| 155 | + The mapping of the axis metadata and the scale data |
| 156 | +
|
| 157 | + Notes |
| 158 | + ----- |
| 159 | + For quantitative formatting, Matplotlib does not support some format strings that Altair supports. |
| 160 | + Matplotlib only supports format strings as used by str.format(). |
| 161 | +
|
| 162 | + For formatting of temporal data, Matplotlib does not support some format strings that Altair supports (%L, %Q, %s). |
| 163 | + Matplotlib only supports datetime.strftime formatting for dates. |
| 164 | + """ |
| 165 | + current_axis = {'x': channel['ax'].xaxis, 'y': channel['ax'].yaxis} |
| 166 | + format_str = '' |
| 167 | + |
| 168 | + if 'format' in axis: |
| 169 | + format_str = axis['format'] |
| 170 | + |
| 171 | + if channel['dtype'] == 'temporal': |
| 172 | + if not format_str: |
| 173 | + format_str = '%b %d, %Y' |
| 174 | + |
| 175 | + current_axis[channel['axis']].set_major_formatter(mdates.DateFormatter(format_str)) # May fail silently |
| 176 | + |
| 177 | + elif channel['dtype'] == 'quantitative': |
| 178 | + if format_str: |
| 179 | + current_axis[channel['axis']].set_major_formatter(ticker.StrMethodFormatter('{x:' + format_str + '}')) |
| 180 | + |
| 181 | + # Verify that the format string is valid for Matplotlib and exit nicely if not. |
| 182 | + try: |
| 183 | + current_axis[channel['axis']].get_major_formatter().__call__(1) |
| 184 | + except ValueError: |
| 185 | + raise ValueError("Matplotlib only supports format strings as used by `str.format()`." |
| 186 | + "Some format strings that work in Altair may not work in Matplotlib." |
| 187 | + "Please use a different format string.") |
| 188 | + else: |
| 189 | + raise NotImplementedError # Nominal and Ordinal go here |
| 190 | + |
| 191 | + |
| 192 | +def _set_label_angle(channel, axis): |
| 193 | + """Set the label angle. TODO: handle axis.labelAngle from Altair |
| 194 | +
|
| 195 | + Parameters |
| 196 | + ---------- |
| 197 | + channel : dict |
| 198 | + The mapping of the channel data and metadata |
| 199 | + axis : dict |
| 200 | + The mapping of the axis metadata and the scale data |
| 201 | + """ |
| 202 | + if channel['dtype'] == 'temporal' and channel['axis'] == 'x': |
| 203 | + for label in channel['ax'].get_xticklabels(): |
| 204 | + # Rotate the labels on the x-axis so they don't run into each other. |
| 205 | + label.set_rotation(30) |
| 206 | + label.set_ha('right') |
| 207 | + |
| 208 | + |
| 209 | +def convert_axis(ax, chart): |
| 210 | + """Convert elements of the altair chart to Matplotlib axis properties |
| 211 | +
|
| 212 | + Parameters |
| 213 | + ---------- |
| 214 | + ax |
| 215 | + The Matplotlib axis to be modified |
| 216 | + chart |
| 217 | + The Altair chart |
| 218 | + """ |
| 219 | + |
| 220 | + for channel in chart.to_dict()['encoding']: |
| 221 | + if channel in ['x', 'y']: |
| 222 | + chart_info = {'ax': ax, 'axis': channel, |
| 223 | + 'data': _locate_channel_data(chart, channel), |
| 224 | + 'dtype': _locate_channel_dtype(chart, channel)} |
| 225 | + if chart_info['dtype'] == 'temporal': |
| 226 | + chart_info['data'] = _convert_to_mpl_date(chart_info['data']) |
| 227 | + |
| 228 | + scale_info = _locate_channel_scale(chart, channel) |
| 229 | + axis_info = _locate_channel_axis(chart, channel) |
| 230 | + |
| 231 | + _set_limits(chart_info, scale_info) |
| 232 | + _set_tick_locator(chart_info, axis_info) |
| 233 | + _set_tick_formatter(chart_info, axis_info) |
| 234 | + _set_label_angle(chart_info, axis_info) |
0 commit comments