Skip to content

Commit 8f4c183

Browse files
committed
FEAT: implemented animate keyword in Array.plot()
1 parent 206d809 commit 8f4c183

File tree

3 files changed

+117
-40
lines changed

3 files changed

+117
-40
lines changed

doc/source/changes/version_0_35.rst.inc

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,15 @@ Backward incompatible changes
2121
New features
2222
^^^^^^^^^^^^
2323

24+
* Array.plot now has an ´animate´ argument to produce animated plots. The
25+
argument takes an axis (it also supports several axes but that is rarely
26+
useful) and will create an animation, with one image per label of that axis.
27+
For example,
28+
29+
>>> arr.plot.bar(animate='year')
30+
31+
will create an animated bar plot with one frame per year.
32+
2433
* added a feature (see the :ref:`miscellaneous section <misc>` for details). It works on :ref:`api-axis` and
2534
:ref:`api-group` objects.
2635

larray/core/array.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7197,6 +7197,28 @@ def plot(self) -> PlotObject:
71977197
- if a tuple of Axis (or int or str), stack each combination of labels of those axes.
71987198
- True is equivalent to all axes (not already used in other arguments) except the last.
71997199
Defaults to False in line and bar plots, and True in area plot.
7200+
animate : Axis, int, str or tuple, optional
7201+
Make an animated plot.
7202+
- if an Axis (or int or str), animate that axis (create one image per label on that axis).
7203+
One would usually use a time-related axis.
7204+
- if a tuple of Axis (or int or str), animate each combination of labels of those axes.
7205+
Defaults to None.
7206+
anim_params: dict, optional
7207+
Optional parameters to control how animations are saved to file.
7208+
- writer : str, optional
7209+
Backend to use. Defaults to 'pillow' for images (.gif .png and
7210+
.tiff), 'ffmpeg' otherwise.
7211+
- fps : int, optional
7212+
Animation frame rate (per second). Defaults to 5.
7213+
- metadata : dict, optional
7214+
Dictionary of metadata to include in the output file.
7215+
Some keys that may be of use include: title, artist, genre,
7216+
subject, copyright, srcform, comment. Defaults to {}.
7217+
- bitrate : int, optional
7218+
The bitrate of the movie, in kilobits per second. Higher values
7219+
means higher quality movies, but increase the file size.
7220+
A value of -1 lets the underlying movie encoder select the
7221+
bitrate.
72007222
**kwargs : keywords
72017223
Options to pass to matplotlib plotting method
72027224
@@ -7238,6 +7260,11 @@ def plot(self) -> PlotObject:
72387260
72397261
>>> arr.plot.bar(stack='gender')
72407262
7263+
An animated bar chart (with two bars). We set explicit y bounds via ylim so that the
7264+
same boundaries are used for the whole animation.
7265+
7266+
>>> arr.plot.bar(animate='year', ylim=(0, 22)) # doctest: +SKIP
7267+
72417268
Create a figure containing 2 x 2 graphs
72427269
72437270
>>> # see matplotlib.pyplot.subplots documentation for more details

larray/core/plot.py

Lines changed: 81 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,12 @@
1+
from pathlib import Path
12
import warnings
23

34
import numpy as np
45
import pandas as pd
56

6-
from larray import IGroup, Axis, AxisCollection, Group
7+
from larray.core.abstractbases import ABCArray
8+
from larray.core.axis import Axis, AxisCollection
9+
from larray.core.group import Group, IGroup
710
from larray.util.misc import deprecate_kwarg
811

912

@@ -19,7 +22,7 @@ def __init__(self, array):
1922
self.array = array
2023

2124
@staticmethod
22-
def _handle_x_y_axes(axes, x, y, subplots):
25+
def _handle_x_y_axes(axes, animate, subplots, x, y):
2326
label_axis = None
2427

2528
if np.isscalar(x) and x not in axes:
@@ -37,15 +40,21 @@ def _handle_x_y_axes(axes, x, y, subplots):
3740
def handle_axes_arg(avail_axes, arg):
3841
if arg is not None:
3942
arg = avail_axes[arg]
43+
avail_axes = avail_axes - arg
4044
if isinstance(arg, Axis):
4145
arg = AxisCollection([arg])
42-
avail_axes = avail_axes - arg
4346
return avail_axes, arg
4447

48+
available_axes = axes
49+
if animate:
50+
available_axes, animate_axes = handle_axes_arg(available_axes, animate)
51+
else:
52+
animate_axes = AxisCollection()
53+
4554
if label_axis is not None:
46-
available_axes = axes - label_axis
55+
available_axes = available_axes - label_axis
4756
else:
48-
available_axes, x = handle_axes_arg(axes, x)
57+
available_axes, x = handle_axes_arg(available_axes, x)
4958
available_axes, y = handle_axes_arg(available_axes, y)
5059

5160
if subplots is True:
@@ -85,7 +94,7 @@ def handle_axes_arg(avail_axes, arg):
8594
assert isinstance(subplot_axes, AxisCollection)
8695
assert y is None
8796

88-
return subplot_axes, series_axes, x, y
97+
return animate_axes, subplot_axes, x, y, series_axes
8998

9099
@staticmethod
91100
def _to_pd_obj(array):
@@ -172,7 +181,7 @@ def _plot_array(array, *args, x=None, y=None, series=None, _x_axes_last=False, *
172181
@deprecate_kwarg('stacked', 'stack')
173182
def __call__(self, x=None, y=None, ax=None, subplots=False, layout=None, figsize=None,
174183
sharex=None, sharey=False, tight_layout=None, constrained_layout=None, title=None, legend=None,
175-
**kwargs):
184+
animate=None, filepath=None, **kwargs):
176185
from matplotlib import pyplot as plt
177186

178187
array = self.array
@@ -190,47 +199,56 @@ def __call__(self, x=None, y=None, ax=None, subplots=False, layout=None, figsize
190199
"stack=axis_name instead", FutureWarning)
191200
kwargs['stacked'] = True
192201

193-
subplot_axes, series_axes, x, y = PlotObject._handle_x_y_axes(array.axes, x, y, subplots)
202+
animate_axes, subplot_axes, x, y, series_axes = PlotObject._handle_x_y_axes(array.axes, animate, subplots, x, y)
194203

195204
if constrained_layout is None and tight_layout is None:
196205
constrained_layout = True
197206

198-
if subplots:
199-
if ax is not None:
200-
raise ValueError("ax cannot be used in combination with subplots argument")
207+
if ax is None:
201208
fig = plt.figure(figsize=figsize, tight_layout=tight_layout, constrained_layout=constrained_layout)
202209

203-
num_subplots = subplot_axes.size
204-
if layout is None:
205-
subplots_shape = subplot_axes.shape
206-
if len(subplots_shape) > 2:
207-
# default to last axis horizontal, other axes combined vertically
208-
layout = np.prod(subplots_shape[:-1]), subplots_shape[-1]
209-
else:
210-
layout = subplot_axes.shape
210+
if subplots:
211+
if layout is None:
212+
subplots_shape = subplot_axes.shape
213+
if len(subplots_shape) > 2:
214+
# default to last axis horizontal, other axes combined vertically
215+
layout = np.prod(subplots_shape[:-1]), subplots_shape[-1]
216+
else:
217+
layout = subplot_axes.shape
218+
if sharex is None:
219+
sharex = True
220+
ax = fig.subplots(*layout, sharex=sharex, sharey=sharey)
221+
else:
222+
ax = fig.add_subplot()
211223

212-
if sharex is None:
213-
sharex = True
214-
ax = fig.subplots(*layout, sharex=sharex, sharey=sharey)
215-
# it is easier to always work with a flat array
216-
flat_ax = ax.flat
217-
# remove blank plot(s) at the end, if any
218-
if len(flat_ax) > num_subplots:
219-
for plot_ax in flat_ax[num_subplots:]:
220-
plot_ax.remove()
221-
# this not strictly necessary but is cleaner in case we reuse flax_ax
222-
flat_ax = flat_ax[:num_subplots]
223-
if title is not None:
224-
fig.suptitle(title)
225-
for i, (ndkey, subarr) in enumerate(array.items(subplot_axes)):
226-
title = ' '.join(str(ak) for ak in ndkey)
227-
self._plot_array(subarr, x=x, y=y, series=series_axes, ax=flat_ax[i], legend=False, title=title,
228-
**kwargs)
224+
if animate:
225+
import matplotlib.animation as animation
226+
227+
def run(t):
228+
if subplots:
229+
for subplot_ax in ax.flat:
230+
subplot_ax.clear()
231+
else:
232+
ax.clear()
233+
self._plot_many(array[t], ax, kwargs, series_axes, subplot_axes, title, x, y)
234+
# TODO: add support for interpolation between frames/labels
235+
# see https://github.com/julkaar9/pynimate for inspiration
236+
ani = animation.FuncAnimation(fig, run, frames=animate_axes.iter_labels())
237+
if not isinstance(filepath, Path):
238+
filepath = Path(filepath)
239+
print(f"Writing animation to {filepath} ...", end=' ', flush=True)
240+
if '.htm' in filepath.suffix:
241+
filepath.write_text(f'<html>{ani.to_html5_video()}</html>', encoding='utf8')
242+
else:
243+
# writer = self.writer
244+
# if writer is None:
245+
writer = 'pillow' if filepath.suffix == '.gif' else 'ffmpeg'
246+
fps = 5
247+
metadata = None
248+
bitrate = None
249+
ani.save(filepath, writer=writer, fps=fps, metadata=metadata, bitrate=bitrate)
229250
else:
230-
if ax is None:
231-
fig = plt.figure(figsize=figsize, tight_layout=tight_layout, constrained_layout=constrained_layout)
232-
ax = fig.subplots(1, 1)
233-
self._plot_array(array, x=x, y=y, series=series_axes, ax=ax, legend=False, title=title, **kwargs)
251+
self._plot_many(array, ax, kwargs, series_axes, subplot_axes, title, x, y)
234252

235253
if legend or legend is None:
236254
first_ax = ax.flat[0] if subplots else ax
@@ -251,6 +269,29 @@ def __call__(self, x=None, y=None, ax=None, subplots=False, layout=None, figsize
251269
legend_parent.legend(handles, labels, **legend_kwargs)
252270
return ax
253271

272+
def _plot_many(self, array, ax, kwargs, series_axes, subplot_axes, title, x, y):
273+
if len(subplot_axes):
274+
num_subplots = subplot_axes.size
275+
if not isinstance(ax, (np.ndarray, ABCArray)) or ax.size < num_subplots:
276+
raise ValueError(f"ax argument value is not compatible with subplot axes ({subplot_axes})")
277+
# it is easier to always work with a flat array
278+
flat_ax = ax.flat
279+
if title is not None:
280+
fig = flat_ax[0].figure
281+
fig.suptitle(title)
282+
# remove blank plot(s) at the end, if any
283+
if len(flat_ax) > num_subplots:
284+
for plot_ax in flat_ax[num_subplots:]:
285+
plot_ax.remove()
286+
# this not strictly necessary but is cleaner in case we reuse flat_ax
287+
flat_ax = flat_ax[:num_subplots]
288+
for i, (ndkey, subarr) in enumerate(array.items(subplot_axes)):
289+
subplot_title = ' '.join(str(ak) for ak in ndkey)
290+
self._plot_array(subarr, x=x, y=y, series=series_axes, ax=flat_ax[i], legend=False, title=subplot_title,
291+
**kwargs)
292+
else:
293+
self._plot_array(array, x=x, y=y, series=series_axes, ax=ax, legend=False, title=title, **kwargs)
294+
254295
@deprecate_kwarg('stacked', 'stack')
255296
@_use_pandas_plot_docstring
256297
def line(self, x=None, y=None, **kwds):

0 commit comments

Comments
 (0)