@@ -103,8 +103,119 @@ def _to_pd_obj(array):
103103 else :
104104 return array .to_frame ()
105105
106+ @staticmethod
107+ def _plot_heat_map (array , x = None , y = None , numhaxes = 1 , axes_names = True , maxticks = 10 , ax = None ,
108+ # TODO: we *might* want to default to False for wildcard axes (for label axes, even
109+ # numeric ones, an inverted axis is more natural)
110+ # TODO: rename to topdown_yaxis or zero_top_yaxis or y0_top or whatever where
111+ # the name actually helps knowing the direction
112+ invert_yaxis = True ,
113+ x_ticks_top = True , colorbar = False , ** kwargs ):
114+ from larray .util .plot import MaxNMultipleWithOffsetLocator
115+
116+ assert ax is not None
117+
118+ # TODO: check if we should handle those here???
119+ kwargs .pop ('kind' )
120+ kwargs .pop ('legend' )
121+ # This is needed to support plotting using imshow (see below)
122+ if 'aspect' not in kwargs :
123+ kwargs ['aspect' ] = 'auto'
124+ if 'origin' not in kwargs :
125+ kwargs ['origin' ] = 'lower'
126+ title = kwargs .pop ('title' , None )
127+ if title is not None :
128+ ax .set_title (title )
129+ if array .ndim < 2 :
130+ array = array .expand (Axis (['' ], '' ))
131+
132+ # TODO: see how much of this is already handled in _plot_array
133+ axes = array .axes
134+ if x is None and y is None :
135+ x = axes [:- numhaxes ]
136+
137+ if y is None :
138+ y = array .axes - x
139+ else :
140+ if isinstance (y , str ):
141+ y = [y ]
142+ y = array .axes [y ]
143+
144+ if x is None :
145+ x = array .axes - y
146+ else :
147+ if isinstance (x , str ):
148+ x = [x ]
149+ x = array .axes [x ]
150+
151+ array = array .transpose (y + x ).combine_axes ([y , x ])
152+
153+ # block size is the size of the other (non-first) combined axes
154+ x_block_size = int (x [1 :].size )
155+ y_block_size = int (y [1 :].size )
156+ c = ax .imshow (array .data , ** kwargs )
157+
158+ # place major ticks in the middle of blocks so that labels are centered
159+ xlabels = x [0 ].labels
160+ ylabels = y [0 ].labels
161+
162+ def format_x_tick (tick_val , tick_pos ):
163+ label_index = int (tick_val ) // x_block_size
164+ return xlabels [label_index ] if label_index < len (xlabels ) else '<bad tick>'
165+
166+ def format_y_tick (tick_val , tick_pos ):
167+ label_index = int (tick_val ) // y_block_size
168+ return ylabels [label_index ] if label_index < len (ylabels ) else '<bad tick>'
169+
170+ # A FuncFormatter is created automatically.
171+ ax .xaxis .set_major_formatter (format_x_tick )
172+ ax .yaxis .set_major_formatter (format_y_tick )
173+
174+ if invert_yaxis :
175+ ax .invert_yaxis ()
176+
177+ # offset=0 because imshow has some kind of builtin offset
178+ x_locator = MaxNMultipleWithOffsetLocator (min (maxticks , len (xlabels )), offset = 0 )
179+ y_locator = MaxNMultipleWithOffsetLocator (min (maxticks , len (ylabels )), offset = 0 )
180+ ax .xaxis .set_major_locator (x_locator )
181+ ax .yaxis .set_major_locator (y_locator )
182+
183+ if x_ticks_top :
184+ ax .xaxis .tick_top ()
185+ ax .xaxis .set_label_position ('top' )
186+
187+ # enable grid lines for minor ticks on axes when we have several "levels" for that axis
188+ if len (x ) > 1 :
189+ # place minor ticks for grid lines between each block on the main axis
190+ ax .set_xticks (np .arange (x_block_size , x .size , x_block_size ), minor = True )
191+ ax .grid (True , axis = 'x' , which = 'minor' )
192+ # hide all ticks on x axis
193+ ax .tick_params (axis = 'x' , which = 'both' , bottom = False , top = False )
194+
195+ if len (y ) > 1 :
196+ ax .set_yticks (np .arange (y_block_size , y .size , y_block_size ), minor = True )
197+ ax .grid (True , axis = 'y' , which = 'minor' )
198+ # hide all ticks on y axis
199+ ax .tick_params (axis = 'y' , which = 'both' , left = False , right = False )
200+
201+ # set axes names
202+ if axes_names :
203+ ax .set_xlabel ('\n ' .join (x .names ))
204+ ax .set_ylabel ('\n ' .join (y .names ))
205+
206+ if colorbar :
207+ ax .figure .colorbar (c )
208+ return ax
209+
106210 @staticmethod
107211 def _plot_array (array , * args , x = None , y = None , series = None , _x_axes_last = False , ** kwargs ):
212+ kind = kwargs .get ('kind' , 'line' )
213+ if kind is None :
214+ kind = 'line'
215+ # heatmaps are special because they do not go via Pandas
216+ if kind == 'heatmap' :
217+ return PlotObject ._plot_heat_map (array , x = x , y = y , ** kwargs )
218+
108219 label_axis = None
109220 if array .ndim == 1 :
110221 pass
@@ -134,9 +245,6 @@ def _plot_array(array, *args, x=None, y=None, series=None, _x_axes_last=False, *
134245 # move label_axis last (it must be a dataframe column)
135246 array = array .transpose (..., label_axis )
136247
137- kind = kwargs .get ('kind' , 'line' )
138- if kind is None :
139- kind = 'line'
140248 lineplot = kind == 'line'
141249 # TODO: why don't we handle all line plots this way?
142250 if lineplot and label_axis is not None and series is not None and len (series ) > 0 :
@@ -239,7 +347,8 @@ def run(t):
239347 def run (t ):
240348 ax .clear ()
241349 self ._plot_many (array [t ], ax , kwargs , series_axes , subplot_axes , title , x , y )
242- # TODO: add support for interpolation between frames/labels
350+ # TODO: add support for interpolation between frames/labels. Would be best to implement this via
351+ # a generic interpolation API in larray though.
243352 # see https://github.com/julkaar9/pynimate for inspiration
244353 ani = FuncAnimation (fig , run , frames = animate_axes .iter_labels ())
245354 else :
@@ -330,6 +439,8 @@ def _plot_many(self, array, ax, kwargs, series_axes, subplot_axes, title, x, y):
330439 plot_ax .remove ()
331440 # this not strictly necessary but is cleaner in case we reuse flat_ax
332441 flat_ax = flat_ax [:num_subplots ]
442+ if kwargs .get ('kind' ) == 'heatmap' and 'x_ticks_top' not in kwargs :
443+ kwargs ['x_ticks_top' ] = False
333444 for i , (ndkey , subarr ) in enumerate (array .items (subplot_axes )):
334445 subplot_title = ' ' .join (str (ak ) for ak in ndkey )
335446 self ._plot_array (subarr , x = x , y = y , series = series_axes , ax = flat_ax [i ], legend = False , title = subplot_title ,
@@ -362,6 +473,35 @@ def box(self, by=None, x=None, **kwds):
362473 ax .get_xaxis ().set_visible (False )
363474 return ax
364475
476+ def heatmap (self , x = None , y = None , ** kwds ):
477+ """plot an ND array as a heatmap.
478+
479+ By default, it uses the last array axis as the X axis and other array axes as Y axis (like the viewer table).
480+ Only the first axis in each "direction" will have its name and labels shown.
481+
482+ Parameters
483+ ----------
484+ arr : Array
485+ data to display.
486+ y_axes : int, str, Axis, tuple or AxisCollection, optional
487+ axis or axes to use on the Y axis. Defaults to all array axes except the last `numhaxes` ones.
488+ x_axes : int, str, Axis, tuple or AxisCollection, optional
489+ axis or axes to use on the X axis. Defaults to all array axes except `y_axes`.
490+ numhaxes : int, optional
491+ if x_axes and y_axes are not specified, use the last numhaxes as X axes. Defaults to 1.
492+ axes_names : bool, optional
493+ whether to show axes names. Defaults to True
494+ ax : matplotlib axes object, optional
495+ **kwargs
496+ any extra keyword argument is passed to Matplotlib imshow.
497+ Likely of interest are cmap, vmin, vmax or norm.
498+
499+ Returns
500+ -------
501+ matplotlib.AxesSubplot
502+ """
503+ return self (kind = 'heatmap' , x = x , y = y , ** kwds )
504+
365505 @_use_pandas_plot_docstring
366506 def hist (self , by = None , bins = 10 , y = None , ** kwds ):
367507 if y is None :
0 commit comments