diff --git a/brian2tools/plotting/morphology.py b/brian2tools/plotting/morphology.py index dce55518..632dca0a 100644 --- a/brian2tools/plotting/morphology.py +++ b/brian2tools/plotting/morphology.py @@ -25,13 +25,19 @@ def _plot_morphology2D(morpho, axes, colors, voltage_colormap, show_diameter=False, show_compartments=True, color_counter=0): - if values is not None: - # Determine colors based on compartment values - normed_values = value_norm(values[morpho.indices[:]]) - colors = voltage_colormap(normed_values) - color = colors[0] + compartment_count = len(morpho.x) + + if values is None: + compartment_colors = [colors[color_counter % len(colors)]] * compartment_count else: - color = colors[color_counter % len(colors)] + try: + section_values = values[morpho.indices[:]] + except (IndexError, TypeError): + # Keep scalar behavior: one value means one color everywhere. + section_values = np.repeat(values, compartment_count) + normed_values = value_norm(section_values) + compartment_colors = voltage_colormap(normed_values) + color = compartment_colors[0] if isinstance(morpho, Soma): x, y = float(morpho.x[0]/um), float(morpho.y[0]/um) @@ -46,25 +52,35 @@ def _plot_morphology2D(morpho, axes, colors, if show_diameter: coords_2d = coords[:, :2] directions = np.diff(coords_2d, axis=0) - orthogonal = np.vstack([-directions[:, 1], directions[:, 0]]) - orthogonal = np.vstack([orthogonal.T, orthogonal[:, -1:].T]) - radius = np.hstack([morpho.start_diameter[0]/um/2, - morpho.end_diameter/um/2]) + orthogonal = np.vstack([-directions[:, 1], directions[:, 0]]).T orthogonal /= np.sqrt(np.sum(orthogonal**2, axis=1))[:, np.newaxis] - points = np.vstack([coords_2d + orthogonal*radius[:, np.newaxis], - (coords_2d - orthogonal*radius[:, np.newaxis])[::-1]]) - patch = Polygon(points, color=color) - axes.add_patch(patch) - + start_radius = morpho.start_diameter/um/2 + end_radius = morpho.end_diameter/um/2 + for idx, color in enumerate(compartment_colors): + start_point = coords_2d[idx] + end_point = coords_2d[idx + 1] + ortho = orthogonal[idx] + points = np.vstack([start_point + ortho*start_radius[idx], + end_point + ortho*end_radius[idx], + end_point - ortho*end_radius[idx], + start_point - ortho*start_radius[idx]]) + patch = Polygon(points, color=color) + axes.add_patch(patch) else: - axes.plot(coords[:, 0], coords[:, 1], color=color, lw=2) + for idx, color in enumerate(compartment_colors): + axes.plot(coords[idx:idx + 2, 0], coords[idx:idx + 2, 1], + color=color, lw=2) if show_compartments: # dots at the center of the compartments if show_diameter: color = 'black' - axes.plot(morpho.x/um, morpho.y/um, '.', color=color, - mec='none', alpha=0.75) + axes.plot(morpho.x/um, morpho.y/um, '.', color=color, + mec='none', alpha=0.75) + else: + axes.scatter(morpho.x/um, morpho.y/um, + c=compartment_colors, + marker='.', edgecolors='none', alpha=0.75) for child in morpho.children: _plot_morphology2D(child, axes=axes, diff --git a/brian2tools/tests/test_plotting.py b/brian2tools/tests/test_plotting.py index e1ec6a29..90b170a3 100644 --- a/brian2tools/tests/test_plotting.py +++ b/brian2tools/tests/test_plotting.py @@ -199,6 +199,27 @@ def test_plot_morphology_values(): plot_3d=False) +def test_plot_morphology_values_per_compartment_2d(): + set_device('runtime') + morpho = Soma(diameter=20*um) + morpho.axon = Cylinder(diameter=2*um, n=3, length=30*um) + morpho = morpho.generate_coordinates() + + # one value for the soma and three different values for the axon compartments + values = np.array([0., 1., 2., 3.]) + ax = plot_morphology(morpho, values=values, plot_3d=False, + show_compartments=False, show_diameter=False) + + # For the axon (n=3) we expect one plotted line segment per compartment. + section_lines = [line for line in ax.lines if line.get_linewidth() == 2] + assert len(section_lines) == 3 + + # Compartment values differ, therefore at least two colors should differ. + section_colors = [tuple(line.get_color()) for line in section_lines] + assert len(set(section_colors)) > 1 + plt.close() + + if __name__ == '__main__': test_plot_monitors() test_plot_synapses()