From 6930c402d38b66ceb941e601830fee67edae0944 Mon Sep 17 00:00:00 2001 From: Utkarsha Dunde Date: Sat, 14 Mar 2026 17:44:28 +0530 Subject: [PATCH 1/3] Feature: properly route multi-variable StateMonitors to subplots --- brian2tools/plotting/base.py | 79 +++++++++++++++++++++++++++--------- 1 file changed, 60 insertions(+), 19 deletions(-) diff --git a/brian2tools/plotting/base.py b/brian2tools/plotting/base.py index 28bb385d..a2ab4b4d 100644 --- a/brian2tools/plotting/base.py +++ b/brian2tools/plotting/base.py @@ -79,27 +79,68 @@ def brian_plot(brian_obj, if isinstance(brian_obj, SpikeMonitor): return plot_raster(brian_obj.i, brian_obj.t, axes=axes, **kwds) elif isinstance(brian_obj, StateMonitor): - if len(brian_obj.record_variables) != 1: - raise TypeError('brian_plot only works for a StateMonitor that ' - 'records a single variable.') - values = getattr(brian_obj, brian_obj.record_variables[0]).T - if 'var_name' not in kwds: - kwds['var_name'] = brian_obj.record_variables[0] - if 'var_unit' not in kwds and isinstance(values, Quantity): - kwds['var_unit'] = _get_best_unit(values) - return plot_state(brian_obj.t, values, axes=axes, **kwds) + n_vars = len(brian_obj.record_variables) + if n_vars == 1: + var_name = brian_obj.record_variables[0] + values = getattr(brian_obj, var_name).T + if 'var_name' not in kwds: + kwds['var_name'] = var_name + if 'var_unit' not in kwds and isinstance(values, Quantity): + kwds['var_unit'] = _get_best_unit(values) + return plot_state(brian_obj.t, values, axes=axes, **kwds) + else: + if axes is None: + fig, axes_arr = plt.subplots(n_vars, 1, sharex=True) + else: + try: + axes_arr = np.asarray(axes).ravel() + assert len(axes_arr) == n_vars + except Exception: + raise TypeError("If multiple variables are recorded, 'axes' must " + "be an iterable of matching length.") + ret_axes = [] + for i, var_name in enumerate(brian_obj.record_variables): + values = getattr(brian_obj, var_name).T + kwds_var = kwds.copy() + if 'var_name' not in kwds_var: + kwds_var['var_name'] = var_name + if 'var_unit' not in kwds_var and isinstance(values, Quantity): + kwds_var['var_unit'] = _get_best_unit(values) + ax = plot_state(brian_obj.t, values, axes=axes_arr[i], **kwds_var) + ret_axes.append(ax) + return ret_axes elif isinstance(brian_obj, StateMonitorView): monitor = brian_obj.monitor - if len(monitor.record_variables) != 1: - raise TypeError('brian_plot only works for a StateMonitor that ' - 'records a single variable.') - var_name = monitor.record_variables[0] - values = getattr(brian_obj, var_name).T - if 'var_name' not in kwds: - kwds['var_name'] = var_name - if 'var_unit' not in kwds and isinstance(values, Quantity): - kwds['var_unit'] = _get_best_unit(values) - return plot_state(brian_obj.t, values, axes=axes, **kwds) + n_vars = len(monitor.record_variables) + if n_vars == 1: + var_name = monitor.record_variables[0] + values = getattr(brian_obj, var_name).T + if 'var_name' not in kwds: + kwds['var_name'] = var_name + if 'var_unit' not in kwds and isinstance(values, Quantity): + kwds['var_unit'] = _get_best_unit(values) + return plot_state(brian_obj.t, values, axes=axes, **kwds) + else: + if axes is None: + fig, axes_arr = plt.subplots(n_vars, 1, sharex=True) + else: + try: + axes_arr = np.asarray(axes).ravel() + assert len(axes_arr) == n_vars + except Exception: + raise TypeError("If multiple variables are recorded, 'axes' must " + "be an iterable of matching length.") + ret_axes = [] + for i, var_name in enumerate(monitor.record_variables): + values = getattr(brian_obj, var_name).T + kwds_var = kwds.copy() + if 'var_name' not in kwds_var: + kwds_var['var_name'] = var_name + if 'var_unit' not in kwds_var and isinstance(values, Quantity): + kwds_var['var_unit'] = _get_best_unit(values) + ax = plot_state(brian_obj.t, values, axes=axes_arr[i], **kwds_var) + ret_axes.append(ax) + return ret_axes elif isinstance(brian_obj, PopulationRateMonitor): smooth_rate = brian_obj.smooth_rate(width=1*ms) if 'rate_unit' not in kwds: From df4450f1dacce4cd397da8bbd06b2232cc323077 Mon Sep 17 00:00:00 2001 From: Utkarsha Dunde Date: Mon, 30 Mar 2026 01:58:50 +0530 Subject: [PATCH 2/3] Address review comments: extract helpers, support dict kwargs, add tests & docsAddress review comments: extract helpers, support dict kwargs, add tests & docs --- brian2tools/plotting/base.py | 154 ++++++++++++++++------------- brian2tools/tests/test_plotting.py | 41 ++++++++ docs_sphinx/user/plotting.rst | 22 +++++ 3 files changed, 150 insertions(+), 67 deletions(-) diff --git a/brian2tools/plotting/base.py b/brian2tools/plotting/base.py index 749769da..2c6581a9 100644 --- a/brian2tools/plotting/base.py +++ b/brian2tools/plotting/base.py @@ -23,6 +23,71 @@ logger = get_logger(__name__) +def _resolve_var_kwds(kwds, var_name, values): + ''' + Build a per-variable keyword dict, resolving ``var_name`` and ``var_unit`` + that may be plain values (applied to every variable) or dicts keyed by + variable name. + ''' + kwds_var = {k: v for k, v in kwds.items() + if k not in ('var_name', 'var_unit')} + + vn = kwds.get('var_name') + if isinstance(vn, dict): + kwds_var['var_name'] = vn.get(var_name, var_name) + elif vn is not None: + kwds_var['var_name'] = vn + else: + kwds_var['var_name'] = var_name + + vu = kwds.get('var_unit') + if isinstance(vu, dict): + unit = vu.get(var_name) + if unit is not None: + kwds_var['var_unit'] = unit + elif isinstance(values, Quantity): + kwds_var['var_unit'] = _get_best_unit(values) + elif vu is not None: + kwds_var['var_unit'] = vu + elif isinstance(values, Quantity): + kwds_var['var_unit'] = _get_best_unit(values) + + return kwds_var + + +def _plot_state_variables(brian_obj, record_variables, axes, **kwds): + ''' + Plot one or more recorded state variables. When more than one variable + is present a column of subplots sharing the time axis is created + automatically (unless the caller supplies a matching array of axes). + ''' + n_vars = len(record_variables) + + if n_vars == 1: + var_name = record_variables[0] + values = getattr(brian_obj, var_name).T + kwds_var = _resolve_var_kwds(kwds, var_name, values) + return plot_state(brian_obj.t, values, axes=axes, **kwds_var) + + if axes is None: + fig, axes_arr = plt.subplots(n_vars, 1, sharex=True) + else: + axes_arr = np.asarray(axes).ravel() + if len(axes_arr) != n_vars: + raise TypeError( + "If multiple variables are recorded, 'axes' must be an " + "array-like of Axes with length %d (got %d)." + % (n_vars, len(axes_arr))) + + ret_axes = [] + for i, var_name in enumerate(record_variables): + values = getattr(brian_obj, var_name).T + kwds_var = _resolve_var_kwds(kwds, var_name, values) + ret_axes.append( + plot_state(brian_obj.t, values, axes=axes_arr[i], **kwds_var)) + return ret_axes + + def _setup_axes_matplotlib(axes): ''' Helper function to create new figures/axes for matplotlib, depending on @@ -56,91 +121,46 @@ def brian_plot(brian_obj, change. This function is therefore mostly meant as a quick and easy way to plot an object, for full control use one of the specific plotting functions. + When a `~brian2.monitors.statemonitor.StateMonitor` that records several + variables is given, a column of subplots sharing the time axis is created + automatically. + Parameters ---------- brian_obj : object The Brian object to plot. - axes : `~matplotlib.axes.Axes`, optional + axes : `~matplotlib.axes.Axes` or array-like of Axes, optional The `~matplotlib.axes.Axes` instance used for plotting. Defaults to ``None`` which means that a new `~matplotlib.axes.Axes` will be - created for the plot. + created for the plot. For a multi-variable + `~brian2.monitors.statemonitor.StateMonitor`, pass an array-like of + `~matplotlib.axes.Axes` with one entry per recorded variable. kwds : dict, optional Any additional keywords command will be handed over to matplotlib's `~matplotlib.axes.Axes.plot` command. This can be used to set plot properties such as the ``color``. + For multi-variable `~brian2.monitors.statemonitor.StateMonitor` + objects, ``var_name`` and ``var_unit`` may be dictionaries keyed by + variable name, e.g. + ``var_name={'v': 'membrane potential', 'I': 'input current'}``. + Returns ------- - axes : `~matplotlib.axes.Axes` - The `~matplotlib.axes.Axes` instance that was used for plotting. This - object allows to modify the plot further, e.g. by setting the plotted - range, the axis labels, the plot title, etc. + axes : `~matplotlib.axes.Axes` or list of `~matplotlib.axes.Axes` + The `~matplotlib.axes.Axes` instance(s) used for plotting. A list is + returned when multiple state variables are plotted. ''' if isinstance(brian_obj, SpikeMonitor): return plot_raster(brian_obj.i, brian_obj.t, axes=axes, **kwds) elif isinstance(brian_obj, StateMonitor): - n_vars = len(brian_obj.record_variables) - if n_vars == 1: - var_name = brian_obj.record_variables[0] - values = getattr(brian_obj, var_name).T - if 'var_name' not in kwds: - kwds['var_name'] = var_name - if 'var_unit' not in kwds and isinstance(values, Quantity): - kwds['var_unit'] = _get_best_unit(values) - return plot_state(brian_obj.t, values, axes=axes, **kwds) - else: - if axes is None: - fig, axes_arr = plt.subplots(n_vars, 1, sharex=True) - else: - try: - axes_arr = np.asarray(axes).ravel() - assert len(axes_arr) == n_vars - except Exception: - raise TypeError("If multiple variables are recorded, 'axes' must " - "be an iterable of matching length.") - ret_axes = [] - for i, var_name in enumerate(brian_obj.record_variables): - values = getattr(brian_obj, var_name).T - kwds_var = kwds.copy() - if 'var_name' not in kwds_var: - kwds_var['var_name'] = var_name - if 'var_unit' not in kwds_var and isinstance(values, Quantity): - kwds_var['var_unit'] = _get_best_unit(values) - ax = plot_state(brian_obj.t, values, axes=axes_arr[i], **kwds_var) - ret_axes.append(ax) - return ret_axes + return _plot_state_variables(brian_obj, + brian_obj.record_variables, + axes, **kwds) elif isinstance(brian_obj, StateMonitorView): - monitor = brian_obj.monitor - n_vars = len(monitor.record_variables) - if n_vars == 1: - var_name = monitor.record_variables[0] - values = getattr(brian_obj, var_name).T - if 'var_name' not in kwds: - kwds['var_name'] = var_name - if 'var_unit' not in kwds and isinstance(values, Quantity): - kwds['var_unit'] = _get_best_unit(values) - return plot_state(brian_obj.t, values, axes=axes, **kwds) - else: - if axes is None: - fig, axes_arr = plt.subplots(n_vars, 1, sharex=True) - else: - try: - axes_arr = np.asarray(axes).ravel() - assert len(axes_arr) == n_vars - except Exception: - raise TypeError("If multiple variables are recorded, 'axes' must " - "be an iterable of matching length.") - ret_axes = [] - for i, var_name in enumerate(monitor.record_variables): - values = getattr(brian_obj, var_name).T - kwds_var = kwds.copy() - if 'var_name' not in kwds_var: - kwds_var['var_name'] = var_name - if 'var_unit' not in kwds_var and isinstance(values, Quantity): - kwds_var['var_unit'] = _get_best_unit(values) - ax = plot_state(brian_obj.t, values, axes=axes_arr[i], **kwds_var) - ret_axes.append(ax) - return ret_axes + return _plot_state_variables(brian_obj, + brian_obj.monitor.record_variables, + axes, **kwds) elif isinstance(brian_obj, PopulationRateMonitor): smooth_rate = brian_obj.smooth_rate(width=1*ms) if 'rate_unit' not in kwds: diff --git a/brian2tools/tests/test_plotting.py b/brian2tools/tests/test_plotting.py index e1ec6a29..01ec37c8 100644 --- a/brian2tools/tests/test_plotting.py +++ b/brian2tools/tests/test_plotting.py @@ -50,6 +50,46 @@ def test_plot_monitors(): assert isinstance(ax, matplotlib.axes.Axes) +def test_plot_multivar_monitors(): + set_device('runtime') + group = NeuronGroup(10, '''dv/dt = -v/(10*ms) : volt + dw/dt = -w/(10*ms) : volt''', + threshold='False', reset='', method='linear') + group.v = np.linspace(0, 1, 10)*mV + group.w = np.linspace(0, 0.5, 10)*mV + state_mon = StateMonitor(group, ['v', 'w'], record=[3, 5]) + run(10*ms) + + # Multi-variable StateMonitor should return a list of Axes + axes = brian_plot(state_mon) + assert isinstance(axes, list) + assert len(axes) == 2 + for ax in axes: + assert isinstance(ax, matplotlib.axes.Axes) + plt.close() + + # Pre-created axes of matching length should work + fig, ax_arr = plt.subplots(2, 1, sharex=True) + axes = brian_plot(state_mon, axes=ax_arr) + assert isinstance(axes, list) + assert len(axes) == 2 + plt.close() + + # Wrong number of axes should raise TypeError + fig, bad_axes = plt.subplots(3, 1) + with pytest.raises(TypeError): + brian_plot(state_mon, axes=bad_axes) + plt.close() + + # StateMonitorView of a multi-variable monitor + axes = brian_plot(state_mon[3]) + assert isinstance(axes, list) + assert len(axes) == 2 + for ax in axes: + assert isinstance(ax, matplotlib.axes.Axes) + plt.close() + + def test_plot_synapses(): set_device('runtime') group = NeuronGroup(10, 'dv/dt = -v/(10*ms) : volt', threshold='False', @@ -201,5 +241,6 @@ def test_plot_morphology_values(): if __name__ == '__main__': test_plot_monitors() + test_plot_multivar_monitors() test_plot_synapses() test_plot_morphology() diff --git a/docs_sphinx/user/plotting.rst b/docs_sphinx/user/plotting.rst index 2e87c2b2..aa05f2d3 100644 --- a/docs_sphinx/user/plotting.rst +++ b/docs_sphinx/user/plotting.rst @@ -119,6 +119,28 @@ demonstrate the use of the returned `~matplotlib.axes.Axes` object to add a lege .. image:: ../images/plot_state.svg +Multiple state variables +~~~~~~~~~~~~~~~~~~~~~~~~ +If the `~brian2.monitors.statemonitor.StateMonitor` records several variables, +`~brian2tools.plotting.base.brian_plot` automatically creates stacked subplots sharing the time axis. +Using the same CUBA model as above, but recording all three state variables:: + + multi_mon = StateMonitor(P, ['v', 'ge', 'gi'], record=[0, 100, 1000]) + run(1 * second) + brian_plot(multi_mon) + +Custom display names and units can be provided per variable via dictionaries:: + + brian_plot(multi_mon, + var_name={'v': 'membrane potential', + 'ge': 'excitatory input', + 'gi': 'inhibitory input'}) + +You can also supply your own pre-created axes (one per variable):: + + fig, axes = plt.subplots(3, 1, sharex=True) + brian_plot(multi_mon, axes=axes) + Plotting synaptic connections and variables ------------------------------------------- For the following examples, we create synapses and synaptic weights according to "distances" (differences between the From f35d787470f00fda9dbf4049e14106052f2200a2 Mon Sep 17 00:00:00 2001 From: Utkarsha Dunde Date: Wed, 29 Apr 2026 00:54:17 +0530 Subject: [PATCH 3/3] Addressing review comments and adding the image. --- brian2tools/plotting/base.py | 10 +- .../images/brian_plot_multivar_state_mon.svg | 8384 +++++++++++++++++ docs_sphinx/user/plotting.rst | 2 + 3 files changed, 8391 insertions(+), 5 deletions(-) create mode 100644 docs_sphinx/images/brian_plot_multivar_state_mon.svg diff --git a/brian2tools/plotting/base.py b/brian2tools/plotting/base.py index 2c6581a9..0a8f8d22 100644 --- a/brian2tools/plotting/base.py +++ b/brian2tools/plotting/base.py @@ -75,16 +75,16 @@ def _plot_state_variables(brian_obj, record_variables, axes, **kwds): axes_arr = np.asarray(axes).ravel() if len(axes_arr) != n_vars: raise TypeError( - "If multiple variables are recorded, 'axes' must be an " - "array-like of Axes with length %d (got %d)." - % (n_vars, len(axes_arr))) + f"If multiple variables are recorded, 'axes' must be an " + f"array-like of Axes with length {n_vars} (got {len(axes_arr)})." + ) ret_axes = [] - for i, var_name in enumerate(record_variables): + for ax, var_name in zip(axes_arr, record_variables): values = getattr(brian_obj, var_name).T kwds_var = _resolve_var_kwds(kwds, var_name, values) ret_axes.append( - plot_state(brian_obj.t, values, axes=axes_arr[i], **kwds_var)) + plot_state(brian_obj.t, values, axes=ax, **kwds_var)) return ret_axes diff --git a/docs_sphinx/images/brian_plot_multivar_state_mon.svg b/docs_sphinx/images/brian_plot_multivar_state_mon.svg new file mode 100644 index 00000000..aad8be32 --- /dev/null +++ b/docs_sphinx/images/brian_plot_multivar_state_mon.svg @@ -0,0 +1,8384 @@ + + + + + + + + 2026-04-25T01:34:26.607317 + image/svg+xml + + + Matplotlib v3.10.8, https://matplotlib.org/ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + \ No newline at end of file diff --git a/docs_sphinx/user/plotting.rst b/docs_sphinx/user/plotting.rst index aa05f2d3..cbc23db8 100644 --- a/docs_sphinx/user/plotting.rst +++ b/docs_sphinx/user/plotting.rst @@ -129,6 +129,8 @@ Using the same CUBA model as above, but recording all three state variables:: run(1 * second) brian_plot(multi_mon) +.. image:: ../images/brian_plot_multivar_state_mon.svg + Custom display names and units can be provided per variable via dictionaries:: brian_plot(multi_mon,