11import matplotlib .dates as mdates
22import matplotlib .ticker as ticker
33import numpy as np
4- from ._data import _locate_channel_data , _locate_channel_dtype , _locate_channel_scale , _locate_channel_axis , _convert_to_mpl_date
4+ from ._data import _convert_to_mpl_date
55
66
7- def _set_limits (channel , scale ):
7+ def _set_limits (channel , mark , ax ):
88 """Set the axis limits on the Matplotlib axis
99
1010 Parameters
1111 ----------
12- channel : dict
13- The mapping of the channel data and metadata
14- scale : dict
15- The mapping of the scale metadata and the scale data
12+ channel : parse_chart.ChannelMetadata
13+ The channel data and metadata
14+ mark : str
15+ The chart's mark
16+ ax : matplotlib.axes
1617 """
1718
1819 _axis_kwargs = {
@@ -22,136 +23,142 @@ def _set_limits(channel, scale):
2223
2324 lims = {}
2425
25- if channel [ 'dtype' ] == 'quantitative' :
26+ if channel . type == 'quantitative' :
2627 # determine limits
27- if 'domain' in scale : # domain takes precedence over zero in Altair
28- if scale ['domain' ] == 'unaggregated' :
28+ if 'domain' in channel . scale : # domain takes precedence over zero in Altair
29+ if channel . scale ['domain' ] == 'unaggregated' :
2930 raise NotImplementedError
3031 else :
31- lims [_axis_kwargs [channel [ 'axis' ]] .get ('min' )] = scale ['domain' ][0 ]
32- lims [_axis_kwargs [channel [ 'axis' ]] .get ('max' )] = scale ['domain' ][1 ]
33- elif 'type' in scale and scale ['type' ] != 'linear' :
34- lims = _set_scale_type (channel , scale )
32+ lims [_axis_kwargs [channel . name ] .get ('min' )] = channel . scale ['domain' ][0 ]
33+ lims [_axis_kwargs [channel . name ] .get ('max' )] = channel . scale ['domain' ][1 ]
34+ elif 'type' in channel . scale and channel . scale ['type' ] != 'linear' :
35+ lims = _set_scale_type (channel , ax )
3536 else :
36- # Check that a positive minimum is zero if zero is True:
37- if ('zero' not in scale or scale ['zero' ] == True ) and min (channel ['data' ]) > 0 :
38- lims [_axis_kwargs [channel ['axis' ]].get ('min' )] = 0 # quantitative sets min to be 0 by default
37+ # Include zero on the axis (or not).
38+ # In Altair, scale.zero defaults to False unless the data is unbinned quantitative.
39+ if mark == 'line' and channel .name == 'x' :
40+ # Contrary to documentation, Altair defaults to scale.zero=False for the x-axis on line graphs.
41+ # Pass to skip.
42+ pass
43+ else :
44+ # Check that a positive minimum is zero if scale.zero is True:
45+ if ('zero' not in channel .scale or channel .scale ['zero' ] == True ) and min (channel .data ) > 0 :
46+ lims [_axis_kwargs [channel .name ].get ('min' )] = 0 # quantitative sets min to be 0 by default
3947
40- # Check that a negative maximum is zero if zero is True:
41- if ('zero' not in scale or scale ['zero' ] == True ) and max (channel [ ' data' ] ) < 0 :
42- lims [_axis_kwargs [channel [ 'axis' ] ].get ('max' )] = 0
48+ # Check that a negative maximum is zero if scale. zero is True:
49+ if ('zero' not in channel . scale or channel . scale ['zero' ] == True ) and max (channel . data ) < 0 :
50+ lims [_axis_kwargs [channel . name ].get ('max' )] = 0
4351
44- elif channel [ 'dtype' ] == 'temporal' :
52+ elif channel . type == 'temporal' :
4553 # determine limits
46- if 'domain' in scale :
47- domain = _convert_to_mpl_date (scale ['domain' ])
48- lims [_axis_kwargs [channel [ 'axis' ] ].get ('min' )] = domain [0 ]
49- lims [_axis_kwargs [channel [ 'axis' ] ].get ('max' )] = domain [1 ]
50- elif 'type' in scale and scale ['type' ] != 'time' :
51- lims = _set_scale_type (channel , scale )
54+ if 'domain' in channel . scale :
55+ domain = _convert_to_mpl_date (channel . scale ['domain' ])
56+ lims [_axis_kwargs [channel . name ].get ('min' )] = domain [0 ]
57+ lims [_axis_kwargs [channel . name ].get ('max' )] = domain [1 ]
58+ elif 'type' in channel . scale and channel . scale ['type' ] != 'time' :
59+ lims = _set_scale_type (channel , channel . scale )
5260
5361 else :
5462 raise NotImplementedError # Ordinal and Nominal go here?
5563
5664 # set the limits
57- if channel [ 'axis' ] == 'x' :
58- channel [ 'ax' ] .set_xlim (** lims )
65+ if channel . name == 'x' :
66+ ax .set_xlim (** lims )
5967 else :
60- channel [ 'ax' ] .set_ylim (** lims )
68+ ax .set_ylim (** lims )
6169
6270
63- def _set_scale_type (channel , scale ):
71+ def _set_scale_type (channel , ax ):
6472 """If the scale is non-linear, change the scale and return appropriate axis limits.
6573 The 'linear' and 'time' scale types are not included here because quantitative defaults to 'linear'
6674 and temporal defaults to 'time'. The 'utc' and 'sequential' scales are currently not supported.
6775
6876 Parameters
6977 ----------
70- channel : dict
71- The mapping of the channel data and metadata
72- scale : dict
73- The mapping of the scale metadata and the scale data
78+ channel : parse_chart.ChannelMetadata
79+ The channel data and metadata
80+ ax : matplotlib.axes
7481
7582 Returns
7683 -------
7784 lims : dict
7885 The axis limit mapped to the appropriate axis parameter for scales that change axis limit behavior
7986 """
8087 lims = {}
81- if scale ['type' ] == 'log' :
88+ if channel . scale ['type' ] == 'log' :
8289
8390 base = 10 # default base is 10 in altair
84- if 'base' in scale :
85- base = scale ['base' ]
91+ if 'base' in channel . scale :
92+ base = channel . scale ['base' ]
8693
87- if channel [ 'axis' ] == 'x' :
88- channel [ 'ax' ] .set_xscale ('log' , basex = base )
94+ if channel . name == 'x' :
95+ ax .set_xscale ('log' , basex = base )
8996 # lower limit: round down to nearest major tick (using log base change rule)
90- lims ['left' ] = base ** np .floor (np .log10 (channel [ ' data' ] .min ())/ np .log10 (base ))
97+ lims ['left' ] = base ** np .floor (np .log10 (channel . data .min ())/ np .log10 (base ))
9198 else : # y-axis
92- channel [ 'ax' ] .set_yscale ('log' , basey = base )
99+ ax .set_yscale ('log' , basey = base )
93100 # lower limit: round down to nearest major tick (using log base change rule)
94- lims ['bottom' ] = base ** np .floor (np .log10 (channel [ ' data' ] .min ())/ np .log10 (base ))
101+ lims ['bottom' ] = base ** np .floor (np .log10 (channel . data .min ())/ np .log10 (base ))
95102
96- elif scale ['type' ] == 'pow' or scale ['type' ] == 'sqrt' :
103+ elif channel . scale ['type' ] == 'pow' or channel . scale ['type' ] == 'sqrt' :
97104 """The 'sqrt' scale is just the 'pow' scale with exponent = 0.5.
98105 When Matplotlib gets a power scale, the following should work:
99106
100107 exponent = 2 # default exponent value for 'pow' scale
101- if scale['type'] == 'sqrt':
108+ if channel. scale['type'] == 'sqrt':
102109 exponent = 0.5
103- elif 'exponent' in scale:
104- exponent = scale['exponent']
110+ elif 'exponent' in channel. scale:
111+ exponent = channel. scale['exponent']
105112
106- if channel['axis'] == 'x':
107- channel['ax'] .set_xscale('power_scale', exponent=exponent)
113+ if channel.name == 'x':
114+ ax .set_xscale('power_scale', exponent=exponent)
108115 else: # y-axis
109- channel['ax'] .set_yscale('power_scale', exponent=exponent)
116+ ax .set_yscale('power_scale', exponent=exponent)
110117 """
111118 raise NotImplementedError
112119
113- elif scale ['type' ] == 'utc' :
120+ elif channel . scale ['type' ] == 'utc' :
114121 raise NotImplementedError
115- elif scale ['type' ] == 'sequential' :
122+ elif channel . scale ['type' ] == 'sequential' :
116123 raise NotImplementedError ("sequential scales used primarily for continuous colors" )
117124 else :
118125 raise NotImplementedError
119126 return lims
120127
121128
122- def _set_tick_locator (channel , axis ):
129+ def _set_tick_locator (channel , ax ):
123130 """Set the tick locator if it needs to vary from the default locator
124131
125132 Parameters
126133 ----------
127- channel : dict
128- The mapping of the channel data and metadata
129- axis : dict
134+ channel : parse_chart.ChannelMetadata
135+ The channel data and metadata
136+ ax : matplotlib.axes
130137 The mapping of the axis metadata and the scale data
131138 """
132- current_axis = {'x' : channel [ 'ax' ] .xaxis , 'y' : channel [ 'ax' ] .yaxis }
133- if 'values' in axis :
134- if channel [ 'dtype' ] == 'temporal' :
135- current_axis [channel [ 'axis' ]] .set_major_locator (ticker .FixedLocator (_convert_to_mpl_date (axis .get ('values' ))))
136- elif channel [ 'dtype' ] == 'quantitative' :
137- current_axis [channel [ 'axis' ]] .set_major_locator (ticker .FixedLocator (axis .get ('values' )))
139+ current_axis = {'x' : ax .xaxis , 'y' : ax .yaxis }
140+ if 'values' in channel . axis :
141+ if channel . type == 'temporal' :
142+ current_axis [channel . name ] .set_major_locator (ticker .FixedLocator (_convert_to_mpl_date (channel . axis .get ('values' ))))
143+ elif channel . type == 'quantitative' :
144+ current_axis [channel . name ] .set_major_locator (ticker .FixedLocator (channel . axis .get ('values' )))
138145 else :
139146 raise NotImplementedError
140- elif 'tickCount' in axis :
141- current_axis [channel [ 'axis' ] ].set_major_locator (
142- ticker .MaxNLocator (steps = [2 , 5 , 10 ], nbins = axis .get ('tickCount' )+ 1 , min_n_ticks = axis .get ('tickCount' ))
147+ elif 'tickCount' in channel . axis :
148+ current_axis [channel . name ].set_major_locator (
149+ ticker .MaxNLocator (steps = [2 , 5 , 10 ], nbins = channel . axis .get ('tickCount' )+ 1 , min_n_ticks = channel . axis .get ('tickCount' ))
143150 )
144151
145152
146- def _set_tick_formatter (channel , axis ):
153+ def _set_tick_formatter (channel , ax ):
147154 """Set the tick formatter.
148155
149156
150157 Parameters
151158 ----------
152- channel : dict
153- The mapping of the channel data and metadata
154- axis : dict
159+ channel : parse_chart.ChannelMetadata
160+ The channel data and metadata
161+ ax : matplotlib.axes
155162 The mapping of the axis metadata and the scale data
156163
157164 Notes
@@ -162,25 +169,22 @@ def _set_tick_formatter(channel, axis):
162169 For formatting of temporal data, Matplotlib does not support some format strings that Altair supports (%L, %Q, %s).
163170 Matplotlib only supports datetime.strftime formatting for dates.
164171 """
165- current_axis = {'x' : channel ['ax' ].xaxis , 'y' : channel ['ax' ].yaxis }
166- format_str = ''
167-
168- if 'format' in axis :
169- format_str = axis ['format' ]
172+ current_axis = {'x' : ax .xaxis , 'y' : ax .yaxis }
173+ format_str = channel .axis .get ('format' , '' )
170174
171- if channel [ 'dtype' ] == 'temporal' :
175+ if channel . type == 'temporal' :
172176 if not format_str :
173177 format_str = '%b %d, %Y'
174178
175- current_axis [channel [ 'axis' ] ].set_major_formatter (mdates .DateFormatter (format_str )) # May fail silently
179+ current_axis [channel . name ].set_major_formatter (mdates .DateFormatter (format_str )) # May fail silently
176180
177- elif channel [ 'dtype' ] == 'quantitative' :
181+ elif channel . type == 'quantitative' :
178182 if format_str :
179- current_axis [channel [ 'axis' ] ].set_major_formatter (ticker .StrMethodFormatter ('{x:' + format_str + '}' ))
183+ current_axis [channel . name ].set_major_formatter (ticker .StrMethodFormatter ('{x:' + format_str + '}' ))
180184
181185 # Verify that the format string is valid for Matplotlib and exit nicely if not.
182186 try :
183- current_axis [channel [ 'axis' ] ].get_major_formatter ().__call__ (1 )
187+ current_axis [channel . name ].get_major_formatter ().__call__ (1 )
184188 except ValueError :
185189 raise ValueError ("Matplotlib only supports format strings as used by `str.format()`."
186190 "Some format strings that work in Altair may not work in Matplotlib."
@@ -189,18 +193,18 @@ def _set_tick_formatter(channel, axis):
189193 raise NotImplementedError # Nominal and Ordinal go here
190194
191195
192- def _set_label_angle (channel , axis ):
196+ def _set_label_angle (channel , ax ):
193197 """Set the label angle. TODO: handle axis.labelAngle from Altair
194198
195199 Parameters
196200 ----------
197- channel : dict
198- The mapping of the channel data and metadata
199- axis : dict
201+ channel : parse_chart.ChannelMetadata
202+ The channel data and metadata
203+ ax : matplotlib.axes
200204 The mapping of the axis metadata and the scale data
201205 """
202- if channel [ 'dtype' ] == 'temporal' and channel [ 'axis' ] == 'x' :
203- for label in channel [ 'ax' ] .get_xticklabels ():
206+ if channel . type == 'temporal' and channel . name == 'x' :
207+ for label in ax .get_xticklabels ():
204208 # Rotate the labels on the x-axis so they don't run into each other.
205209 label .set_rotation (30 )
206210 label .set_ha ('right' )
@@ -213,22 +217,12 @@ def convert_axis(ax, chart):
213217 ----------
214218 ax
215219 The Matplotlib axis to be modified
216- chart
217- The Altair chart
220+ chart : parse_chart.ChartMetadata
221+ The chart data and metadata
218222 """
219223
220- for channel in chart .to_dict ()['encoding' ]:
221- if channel in ['x' , 'y' ]:
222- chart_info = {'ax' : ax , 'axis' : channel ,
223- 'data' : _locate_channel_data (chart , channel ),
224- 'dtype' : _locate_channel_dtype (chart , channel )}
225- if chart_info ['dtype' ] == 'temporal' :
226- chart_info ['data' ] = _convert_to_mpl_date (chart_info ['data' ])
227-
228- scale_info = _locate_channel_scale (chart , channel )
229- axis_info = _locate_channel_axis (chart , channel )
230-
231- _set_limits (chart_info , scale_info )
232- _set_tick_locator (chart_info , axis_info )
233- _set_tick_formatter (chart_info , axis_info )
234- _set_label_angle (chart_info , axis_info )
224+ for channel in [chart .encoding ['x' ], chart .encoding ['y' ]]:
225+ _set_limits (channel , chart .mark , ax )
226+ _set_tick_locator (channel , ax )
227+ _set_tick_formatter (channel , ax )
228+ _set_label_angle (channel , ax )
0 commit comments