Skip to content

Commit 959e6f1

Browse files
committed
FEAT: implemented heatmap plots
1 parent 35cc775 commit 959e6f1

File tree

4 files changed

+187
-4
lines changed

4 files changed

+187
-4
lines changed

doc/source/changes/version_0_35.rst.inc

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,10 @@ New features
3939
* implemented Array.plot `show` argument to display plots directly, without
4040
having to use the matplotlib API. This is the new default behavior.
4141

42+
* implemented a new kind of plot: `heatmap`. It can be used like this:
43+
44+
>>> arr.plot.heatmap()
45+
4246
* added a feature (see the :ref:`miscellaneous section <misc>` for details). It works on :ref:`api-axis` and
4347
:ref:`api-group` objects.
4448

larray/core/array.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7130,6 +7130,8 @@ def plot(self) -> PlotObject:
71307130
- 'pie' : pie plot
71317131
- 'scatter' : scatter plot (if array's dimensions >= 2)
71327132
- 'hexbin' : hexbin plot (if array's dimensions >= 2)
7133+
- 'heatmap': heatmap plot (if array's dimensions >= 2).
7134+
See Array.plot.heatmap for more details.
71337135
filepath : str or Path, default None
71347136
Save plot as a file at `filepath`. Defaults to None (do not save).
71357137
When saving the plot to a file, the function returns None. In other

larray/core/plot.py

Lines changed: 144 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -103,8 +103,119 @@ def _to_pd_obj(array):
103103
else:
104104
return array.to_frame()
105105

106+
@staticmethod
107+
def _plot_heat_map(array, x=None, y=None, numhaxes=1, axes_names=True, maxticks=10, ax=None,
108+
# TODO: we *might* want to default to False for wildcard axes (for label axes, even
109+
# numeric ones, an inverted axis is more natural)
110+
# TODO: rename to topdown_yaxis or zero_top_yaxis or y0_top or whatever where
111+
# the name actually helps knowing the direction
112+
invert_yaxis=True,
113+
x_ticks_top=True, colorbar=False, **kwargs):
114+
from larray.util.plot import MaxNMultipleWithOffsetLocator
115+
116+
assert ax is not None
117+
118+
# TODO: check if we should handle those here???
119+
kwargs.pop('kind')
120+
kwargs.pop('legend')
121+
# This is needed to support plotting using imshow (see below)
122+
if 'aspect' not in kwargs:
123+
kwargs['aspect'] = 'auto'
124+
if 'origin' not in kwargs:
125+
kwargs['origin'] = 'lower'
126+
title = kwargs.pop('title', None)
127+
if title is not None:
128+
ax.set_title(title)
129+
if array.ndim < 2:
130+
array = array.expand(Axis([''], ''))
131+
132+
# TODO: see how much of this is already handled in _plot_array
133+
axes = array.axes
134+
if x is None and y is None:
135+
x = axes[:-numhaxes]
136+
137+
if y is None:
138+
y = array.axes - x
139+
else:
140+
if isinstance(y, str):
141+
y = [y]
142+
y = array.axes[y]
143+
144+
if x is None:
145+
x = array.axes - y
146+
else:
147+
if isinstance(x, str):
148+
x = [x]
149+
x = array.axes[x]
150+
151+
array = array.transpose(y + x).combine_axes([y, x])
152+
153+
# block size is the size of the other (non-first) combined axes
154+
x_block_size = int(x[1:].size)
155+
y_block_size = int(y[1:].size)
156+
c = ax.imshow(array.data, **kwargs)
157+
158+
# place major ticks in the middle of blocks so that labels are centered
159+
xlabels = x[0].labels
160+
ylabels = y[0].labels
161+
162+
def format_x_tick(tick_val, tick_pos):
163+
label_index = int(tick_val) // x_block_size
164+
return xlabels[label_index] if label_index < len(xlabels) else '<bad tick>'
165+
166+
def format_y_tick(tick_val, tick_pos):
167+
label_index = int(tick_val) // y_block_size
168+
return ylabels[label_index] if label_index < len(ylabels) else '<bad tick>'
169+
170+
# A FuncFormatter is created automatically.
171+
ax.xaxis.set_major_formatter(format_x_tick)
172+
ax.yaxis.set_major_formatter(format_y_tick)
173+
174+
if invert_yaxis:
175+
ax.invert_yaxis()
176+
177+
# offset=0 because imshow has some kind of builtin offset
178+
x_locator = MaxNMultipleWithOffsetLocator(min(maxticks, len(xlabels)), offset=0)
179+
y_locator = MaxNMultipleWithOffsetLocator(min(maxticks, len(ylabels)), offset=0)
180+
ax.xaxis.set_major_locator(x_locator)
181+
ax.yaxis.set_major_locator(y_locator)
182+
183+
if x_ticks_top:
184+
ax.xaxis.tick_top()
185+
ax.xaxis.set_label_position('top')
186+
187+
# enable grid lines for minor ticks on axes when we have several "levels" for that axis
188+
if len(x) > 1:
189+
# place minor ticks for grid lines between each block on the main axis
190+
ax.set_xticks(np.arange(x_block_size, x.size, x_block_size), minor=True)
191+
ax.grid(True, axis='x', which='minor')
192+
# hide all ticks on x axis
193+
ax.tick_params(axis='x', which='both', bottom=False, top=False)
194+
195+
if len(y) > 1:
196+
ax.set_yticks(np.arange(y_block_size, y.size, y_block_size), minor=True)
197+
ax.grid(True, axis='y', which='minor')
198+
# hide all ticks on y axis
199+
ax.tick_params(axis='y', which='both', left=False, right=False)
200+
201+
# set axes names
202+
if axes_names:
203+
ax.set_xlabel('\n'.join(x.names))
204+
ax.set_ylabel('\n'.join(y.names))
205+
206+
if colorbar:
207+
ax.figure.colorbar(c)
208+
return ax
209+
106210
@staticmethod
107211
def _plot_array(array, *args, x=None, y=None, series=None, _x_axes_last=False, **kwargs):
212+
kind = kwargs.get('kind', 'line')
213+
if kind is None:
214+
kind = 'line'
215+
# heatmaps are special because they do not go via Pandas
216+
if kind == 'heatmap':
217+
return PlotObject._plot_heat_map(array, x=x, y=y, **kwargs)
218+
108219
label_axis = None
109220
if array.ndim == 1:
110221
pass
@@ -134,9 +245,6 @@ def _plot_array(array, *args, x=None, y=None, series=None, _x_axes_last=False, *
134245
# move label_axis last (it must be a dataframe column)
135246
array = array.transpose(..., label_axis)
136247

137-
kind = kwargs.get('kind', 'line')
138-
if kind is None:
139-
kind = 'line'
140248
lineplot = kind == 'line'
141249
# TODO: why don't we handle all line plots this way?
142250
if lineplot and label_axis is not None and series is not None and len(series) > 0:
@@ -239,7 +347,8 @@ def run(t):
239347
def run(t):
240348
ax.clear()
241349
self._plot_many(array[t], ax, kwargs, series_axes, subplot_axes, title, x, y)
242-
# TODO: add support for interpolation between frames/labels
350+
# TODO: add support for interpolation between frames/labels. Would be best to implement this via
351+
# a generic interpolation API in larray though.
243352
# see https://github.com/julkaar9/pynimate for inspiration
244353
ani = FuncAnimation(fig, run, frames=animate_axes.iter_labels())
245354
else:
@@ -330,6 +439,8 @@ def _plot_many(self, array, ax, kwargs, series_axes, subplot_axes, title, x, y):
330439
plot_ax.remove()
331440
# this not strictly necessary but is cleaner in case we reuse flat_ax
332441
flat_ax = flat_ax[:num_subplots]
442+
if kwargs.get('kind') == 'heatmap' and 'x_ticks_top' not in kwargs:
443+
kwargs['x_ticks_top'] = False
333444
for i, (ndkey, subarr) in enumerate(array.items(subplot_axes)):
334445
subplot_title = ' '.join(str(ak) for ak in ndkey)
335446
self._plot_array(subarr, x=x, y=y, series=series_axes, ax=flat_ax[i], legend=False, title=subplot_title,
@@ -362,6 +473,35 @@ def box(self, by=None, x=None, **kwds):
362473
ax.get_xaxis().set_visible(False)
363474
return ax
364475

476+
def heatmap(self, x=None, y=None, **kwds):
477+
"""plot an ND array as a heatmap.
478+
479+
By default, it uses the last array axis as the X axis and other array axes as Y axis (like the viewer table).
480+
Only the first axis in each "direction" will have its name and labels shown.
481+
482+
Parameters
483+
----------
484+
arr : Array
485+
data to display.
486+
y_axes : int, str, Axis, tuple or AxisCollection, optional
487+
axis or axes to use on the Y axis. Defaults to all array axes except the last `numhaxes` ones.
488+
x_axes : int, str, Axis, tuple or AxisCollection, optional
489+
axis or axes to use on the X axis. Defaults to all array axes except `y_axes`.
490+
numhaxes : int, optional
491+
if x_axes and y_axes are not specified, use the last numhaxes as X axes. Defaults to 1.
492+
axes_names : bool, optional
493+
whether to show axes names. Defaults to True
494+
ax : matplotlib axes object, optional
495+
**kwargs
496+
any extra keyword argument is passed to Matplotlib imshow.
497+
Likely of interest are cmap, vmin, vmax or norm.
498+
499+
Returns
500+
-------
501+
matplotlib.AxesSubplot
502+
"""
503+
return self(kind='heatmap', x=x, y=y, **kwds)
504+
365505
@_use_pandas_plot_docstring
366506
def hist(self, by=None, bins=10, y=None, **kwds):
367507
if y is None:

larray/util/plot.py

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
import numpy as np
2+
from matplotlib.ticker import MaxNLocator
3+
4+
5+
class MaxNMultipleWithOffsetLocator(MaxNLocator):
6+
def __init__(self, nbins=None, offset=0.5, **kwargs):
7+
super().__init__(nbins, **kwargs)
8+
self.offset = offset
9+
10+
def tick_values(self, vmin, vmax):
11+
# matplotlib calls them vmin and vmax but they are actually the limits and vmin can be > vmax
12+
invert = vmin > vmax
13+
if invert:
14+
vmin, vmax = vmax, vmin
15+
16+
max_desired_ticks = self._nbins
17+
# not + 1 because we place ticks in the middle
18+
num_ticks = vmax - vmin
19+
desired_numticks = min(num_ticks, max_desired_ticks)
20+
if desired_numticks < num_ticks:
21+
step = np.ceil(num_ticks / desired_numticks)
22+
else:
23+
step = 1
24+
vmin = int(vmin)
25+
vmax = int(vmax)
26+
# when we have an offset, we do not add 1 to vmax because we place ticks in the middle
27+
# (by adding the offset), and would result in the last "tick" being outside the limits
28+
stop = vmax + 1 if self.offset == 0 else vmax
29+
new_ticks = np.arange(vmin, stop, step)
30+
if invert:
31+
new_ticks = new_ticks[::-1]
32+
return new_ticks + self.offset
33+
34+
def __call__(self):
35+
"""Return the locations of the ticks."""
36+
vmin, vmax = self.axis.get_view_interval()
37+
return self.tick_values(vmin, vmax)

0 commit comments

Comments
 (0)