Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
113 changes: 87 additions & 26 deletions brian2tools/plotting/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,71 @@
logger = get_logger(__name__)


def _resolve_var_kwds(kwds, var_name, values):
'''
Build a per-variable keyword dict, resolving ``var_name`` and ``var_unit``
that may be plain values (applied to every variable) or dicts keyed by
variable name.
'''
kwds_var = {k: v for k, v in kwds.items()
if k not in ('var_name', 'var_unit')}

vn = kwds.get('var_name')
if isinstance(vn, dict):
kwds_var['var_name'] = vn.get(var_name, var_name)
elif vn is not None:
kwds_var['var_name'] = vn
else:
kwds_var['var_name'] = var_name

vu = kwds.get('var_unit')
if isinstance(vu, dict):
unit = vu.get(var_name)
if unit is not None:
kwds_var['var_unit'] = unit
elif isinstance(values, Quantity):
kwds_var['var_unit'] = _get_best_unit(values)
elif vu is not None:
kwds_var['var_unit'] = vu
elif isinstance(values, Quantity):
kwds_var['var_unit'] = _get_best_unit(values)

return kwds_var


def _plot_state_variables(brian_obj, record_variables, axes, **kwds):
'''
Plot one or more recorded state variables. When more than one variable
is present a column of subplots sharing the time axis is created
automatically (unless the caller supplies a matching array of axes).
'''
n_vars = len(record_variables)

if n_vars == 1:
var_name = record_variables[0]
values = getattr(brian_obj, var_name).T
kwds_var = _resolve_var_kwds(kwds, var_name, values)
return plot_state(brian_obj.t, values, axes=axes, **kwds_var)

if axes is None:
fig, axes_arr = plt.subplots(n_vars, 1, sharex=True)
else:
axes_arr = np.asarray(axes).ravel()
if len(axes_arr) != n_vars:
raise TypeError(
f"If multiple variables are recorded, 'axes' must be an "
f"array-like of Axes with length {n_vars} (got {len(axes_arr)})."
)

ret_axes = []
for ax, var_name in zip(axes_arr, record_variables):
values = getattr(brian_obj, var_name).T
kwds_var = _resolve_var_kwds(kwds, var_name, values)
ret_axes.append(
plot_state(brian_obj.t, values, axes=ax, **kwds_var))
return ret_axes


def _setup_axes_matplotlib(axes):
'''
Helper function to create new figures/axes for matplotlib, depending on
Expand Down Expand Up @@ -56,50 +121,46 @@ def brian_plot(brian_obj,
change. This function is therefore mostly meant as a quick and easy way to
plot an object, for full control use one of the specific plotting functions.

When a `~brian2.monitors.statemonitor.StateMonitor` that records several
variables is given, a column of subplots sharing the time axis is created
automatically.

Parameters
----------
brian_obj : object
The Brian object to plot.
axes : `~matplotlib.axes.Axes`, optional
axes : `~matplotlib.axes.Axes` or array-like of Axes, optional
The `~matplotlib.axes.Axes` instance used for plotting. Defaults to
``None`` which means that a new `~matplotlib.axes.Axes` will be
created for the plot.
created for the plot. For a multi-variable
`~brian2.monitors.statemonitor.StateMonitor`, pass an array-like of
`~matplotlib.axes.Axes` with one entry per recorded variable.
kwds : dict, optional
Any additional keywords command will be handed over to matplotlib's
`~matplotlib.axes.Axes.plot` command. This can be used to set plot
properties such as the ``color``.

For multi-variable `~brian2.monitors.statemonitor.StateMonitor`
objects, ``var_name`` and ``var_unit`` may be dictionaries keyed by
variable name, e.g.
``var_name={'v': 'membrane potential', 'I': 'input current'}``.

Returns
-------
axes : `~matplotlib.axes.Axes`
The `~matplotlib.axes.Axes` instance that was used for plotting. This
object allows to modify the plot further, e.g. by setting the plotted
range, the axis labels, the plot title, etc.
axes : `~matplotlib.axes.Axes` or list of `~matplotlib.axes.Axes`
The `~matplotlib.axes.Axes` instance(s) used for plotting. A list is
returned when multiple state variables are plotted.
'''
if isinstance(brian_obj, SpikeMonitor):
return plot_raster(brian_obj.i, brian_obj.t, axes=axes, **kwds)
elif isinstance(brian_obj, StateMonitor):
if len(brian_obj.record_variables) != 1:
raise TypeError('brian_plot only works for a StateMonitor that '
'records a single variable.')
values = getattr(brian_obj, brian_obj.record_variables[0]).T
if 'var_name' not in kwds:
kwds['var_name'] = brian_obj.record_variables[0]
if 'var_unit' not in kwds and isinstance(values, Quantity):
kwds['var_unit'] = _get_best_unit(values)
return plot_state(brian_obj.t, values, axes=axes, **kwds)
return _plot_state_variables(brian_obj,
brian_obj.record_variables,
axes, **kwds)
elif isinstance(brian_obj, StateMonitorView):
monitor = brian_obj.monitor
if len(monitor.record_variables) != 1:
raise TypeError('brian_plot only works for a StateMonitor that '
'records a single variable.')
var_name = monitor.record_variables[0]
values = getattr(brian_obj, var_name).T
if 'var_name' not in kwds:
kwds['var_name'] = var_name
if 'var_unit' not in kwds and isinstance(values, Quantity):
kwds['var_unit'] = _get_best_unit(values)
return plot_state(brian_obj.t, values, axes=axes, **kwds)
return _plot_state_variables(brian_obj,
brian_obj.monitor.record_variables,
axes, **kwds)
elif isinstance(brian_obj, PopulationRateMonitor):
smooth_rate = brian_obj.smooth_rate(width=1*ms)
if 'rate_unit' not in kwds:
Expand Down
41 changes: 41 additions & 0 deletions brian2tools/tests/test_plotting.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,46 @@ def test_plot_monitors():
assert isinstance(ax, matplotlib.axes.Axes)


def test_plot_multivar_monitors():
set_device('runtime')
group = NeuronGroup(10, '''dv/dt = -v/(10*ms) : volt
dw/dt = -w/(10*ms) : volt''',
threshold='False', reset='', method='linear')
group.v = np.linspace(0, 1, 10)*mV
group.w = np.linspace(0, 0.5, 10)*mV
state_mon = StateMonitor(group, ['v', 'w'], record=[3, 5])
run(10*ms)

# Multi-variable StateMonitor should return a list of Axes
axes = brian_plot(state_mon)
assert isinstance(axes, list)
assert len(axes) == 2
for ax in axes:
assert isinstance(ax, matplotlib.axes.Axes)
plt.close()

# Pre-created axes of matching length should work
fig, ax_arr = plt.subplots(2, 1, sharex=True)
axes = brian_plot(state_mon, axes=ax_arr)
assert isinstance(axes, list)
assert len(axes) == 2
plt.close()

# Wrong number of axes should raise TypeError
fig, bad_axes = plt.subplots(3, 1)
with pytest.raises(TypeError):
brian_plot(state_mon, axes=bad_axes)
plt.close()

# StateMonitorView of a multi-variable monitor
axes = brian_plot(state_mon[3])
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I actually realized that an earlier comment of mine was incorrect – I thought e.g. state_mon.v would return a StateMonitorView (in the same way that for a NeuronGroup group, the access group.v would return a VariableView), but you are right that in StateMonitor, a "view" is used when you index into the group. I wonder whether it would make sense to change something on the Brian side so that you could call e.g. brian_plot(state_mon.v) (which currently does not work, since state_mon.v is a simple Quantity without information about the variable name). But this is outside the scope of this PR, of course.

assert isinstance(axes, list)
assert len(axes) == 2
for ax in axes:
assert isinstance(ax, matplotlib.axes.Axes)
plt.close()


def test_plot_synapses():
set_device('runtime')
group = NeuronGroup(10, 'dv/dt = -v/(10*ms) : volt', threshold='False',
Expand Down Expand Up @@ -222,5 +262,6 @@ def test_plot_morphology_values_per_compartment_2d():

if __name__ == '__main__':
test_plot_monitors()
test_plot_multivar_monitors()
test_plot_synapses()
test_plot_morphology()
Loading