1+ from pathlib import Path
12import warnings
23
34import numpy as np
45import pandas as pd
56
6- from larray import IGroup , Axis , AxisCollection , Group
7+ from larray .core .abstractbases import ABCArray
8+ from larray .core .axis import Axis , AxisCollection
9+ from larray .core .group import Group , IGroup
710from larray .util .misc import deprecate_kwarg
811
912
@@ -19,7 +22,7 @@ def __init__(self, array):
1922 self .array = array
2023
2124 @staticmethod
22- def _handle_x_y_axes (axes , x , y , subplots ):
25+ def _handle_x_y_axes (axes , animate , subplots , x , y ):
2326 label_axis = None
2427
2528 if np .isscalar (x ) and x not in axes :
@@ -37,15 +40,21 @@ def _handle_x_y_axes(axes, x, y, subplots):
3740 def handle_axes_arg (avail_axes , arg ):
3841 if arg is not None :
3942 arg = avail_axes [arg ]
43+ avail_axes = avail_axes - arg
4044 if isinstance (arg , Axis ):
4145 arg = AxisCollection ([arg ])
42- avail_axes = avail_axes - arg
4346 return avail_axes , arg
4447
48+ available_axes = axes
49+ if animate :
50+ available_axes , animate_axes = handle_axes_arg (available_axes , animate )
51+ else :
52+ animate_axes = AxisCollection ()
53+
4554 if label_axis is not None :
46- available_axes = axes - label_axis
55+ available_axes = available_axes - label_axis
4756 else :
48- available_axes , x = handle_axes_arg (axes , x )
57+ available_axes , x = handle_axes_arg (available_axes , x )
4958 available_axes , y = handle_axes_arg (available_axes , y )
5059
5160 if subplots is True :
@@ -85,7 +94,7 @@ def handle_axes_arg(avail_axes, arg):
8594 assert isinstance (subplot_axes , AxisCollection )
8695 assert y is None
8796
88- return subplot_axes , series_axes , x , y
97+ return animate_axes , subplot_axes , x , y , series_axes
8998
9099 @staticmethod
91100 def _to_pd_obj (array ):
@@ -172,7 +181,7 @@ def _plot_array(array, *args, x=None, y=None, series=None, _x_axes_last=False, *
172181 @deprecate_kwarg ('stacked' , 'stack' )
173182 def __call__ (self , x = None , y = None , ax = None , subplots = False , layout = None , figsize = None ,
174183 sharex = None , sharey = False , tight_layout = None , constrained_layout = None , title = None , legend = None ,
175- ** kwargs ):
184+ animate = None , filepath = None , ** kwargs ):
176185 from matplotlib import pyplot as plt
177186
178187 array = self .array
@@ -190,47 +199,56 @@ def __call__(self, x=None, y=None, ax=None, subplots=False, layout=None, figsize
190199 "stack=axis_name instead" , FutureWarning )
191200 kwargs ['stacked' ] = True
192201
193- subplot_axes , series_axes , x , y = PlotObject ._handle_x_y_axes (array .axes , x , y , subplots )
202+ animate_axes , subplot_axes , x , y , series_axes = PlotObject ._handle_x_y_axes (array .axes , animate , subplots , x , y )
194203
195204 if constrained_layout is None and tight_layout is None :
196205 constrained_layout = True
197206
198- if subplots :
199- if ax is not None :
200- raise ValueError ("ax cannot be used in combination with subplots argument" )
207+ if ax is None :
201208 fig = plt .figure (figsize = figsize , tight_layout = tight_layout , constrained_layout = constrained_layout )
202209
203- num_subplots = subplot_axes .size
204- if layout is None :
205- subplots_shape = subplot_axes .shape
206- if len (subplots_shape ) > 2 :
207- # default to last axis horizontal, other axes combined vertically
208- layout = np .prod (subplots_shape [:- 1 ]), subplots_shape [- 1 ]
209- else :
210- layout = subplot_axes .shape
210+ if subplots :
211+ if layout is None :
212+ subplots_shape = subplot_axes .shape
213+ if len (subplots_shape ) > 2 :
214+ # default to last axis horizontal, other axes combined vertically
215+ layout = np .prod (subplots_shape [:- 1 ]), subplots_shape [- 1 ]
216+ else :
217+ layout = subplot_axes .shape
218+ if sharex is None :
219+ sharex = True
220+ ax = fig .subplots (* layout , sharex = sharex , sharey = sharey )
221+ else :
222+ ax = fig .add_subplot ()
211223
212- if sharex is None :
213- sharex = True
214- ax = fig .subplots (* layout , sharex = sharex , sharey = sharey )
215- # it is easier to always work with a flat array
216- flat_ax = ax .flat
217- # remove blank plot(s) at the end, if any
218- if len (flat_ax ) > num_subplots :
219- for plot_ax in flat_ax [num_subplots :]:
220- plot_ax .remove ()
221- # this not strictly necessary but is cleaner in case we reuse flax_ax
222- flat_ax = flat_ax [:num_subplots ]
223- if title is not None :
224- fig .suptitle (title )
225- for i , (ndkey , subarr ) in enumerate (array .items (subplot_axes )):
226- title = ' ' .join (str (ak ) for ak in ndkey )
227- self ._plot_array (subarr , x = x , y = y , series = series_axes , ax = flat_ax [i ], legend = False , title = title ,
228- ** kwargs )
224+ if animate :
225+ import matplotlib .animation as animation
226+
227+ def run (t ):
228+ if subplots :
229+ for subplot_ax in ax .flat :
230+ subplot_ax .clear ()
231+ else :
232+ ax .clear ()
233+ self ._plot_many (array [t ], ax , kwargs , series_axes , subplot_axes , title , x , y )
234+ # TODO: add support for interpolation between frames/labels
235+ # see https://github.com/julkaar9/pynimate for inspiration
236+ ani = animation .FuncAnimation (fig , run , frames = animate_axes .iter_labels ())
237+ if not isinstance (filepath , Path ):
238+ filepath = Path (filepath )
239+ print (f"Writing animation to { filepath } ..." , end = ' ' , flush = True )
240+ if '.htm' in filepath .suffix :
241+ filepath .write_text (f'<html>{ ani .to_html5_video ()} </html>' , encoding = 'utf8' )
242+ else :
243+ # writer = self.writer
244+ # if writer is None:
245+ writer = 'pillow' if filepath .suffix == '.gif' else 'ffmpeg'
246+ fps = 5
247+ metadata = None
248+ bitrate = None
249+ ani .save (filepath , writer = writer , fps = fps , metadata = metadata , bitrate = bitrate )
229250 else :
230- if ax is None :
231- fig = plt .figure (figsize = figsize , tight_layout = tight_layout , constrained_layout = constrained_layout )
232- ax = fig .subplots (1 , 1 )
233- self ._plot_array (array , x = x , y = y , series = series_axes , ax = ax , legend = False , title = title , ** kwargs )
251+ self ._plot_many (array , ax , kwargs , series_axes , subplot_axes , title , x , y )
234252
235253 if legend or legend is None :
236254 first_ax = ax .flat [0 ] if subplots else ax
@@ -251,6 +269,29 @@ def __call__(self, x=None, y=None, ax=None, subplots=False, layout=None, figsize
251269 legend_parent .legend (handles , labels , ** legend_kwargs )
252270 return ax
253271
272+ def _plot_many (self , array , ax , kwargs , series_axes , subplot_axes , title , x , y ):
273+ if len (subplot_axes ):
274+ num_subplots = subplot_axes .size
275+ if not isinstance (ax , (np .ndarray , ABCArray )) or ax .size < num_subplots :
276+ raise ValueError (f"ax argument value is not compatible with subplot axes ({ subplot_axes } )" )
277+ # it is easier to always work with a flat array
278+ flat_ax = ax .flat
279+ if title is not None :
280+ fig = flat_ax [0 ].figure
281+ fig .suptitle (title )
282+ # remove blank plot(s) at the end, if any
283+ if len (flat_ax ) > num_subplots :
284+ for plot_ax in flat_ax [num_subplots :]:
285+ plot_ax .remove ()
286+ # this not strictly necessary but is cleaner in case we reuse flat_ax
287+ flat_ax = flat_ax [:num_subplots ]
288+ for i , (ndkey , subarr ) in enumerate (array .items (subplot_axes )):
289+ subplot_title = ' ' .join (str (ak ) for ak in ndkey )
290+ self ._plot_array (subarr , x = x , y = y , series = series_axes , ax = flat_ax [i ], legend = False , title = subplot_title ,
291+ ** kwargs )
292+ else :
293+ self ._plot_array (array , x = x , y = y , series = series_axes , ax = ax , legend = False , title = title , ** kwargs )
294+
254295 @deprecate_kwarg ('stacked' , 'stack' )
255296 @_use_pandas_plot_docstring
256297 def line (self , x = None , y = None , ** kwds ):
0 commit comments