Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
52 changes: 34 additions & 18 deletions brian2tools/plotting/morphology.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,13 +25,19 @@ def _plot_morphology2D(morpho, axes, colors,
voltage_colormap,
show_diameter=False, show_compartments=True,
color_counter=0):
if values is not None:
# Determine colors based on compartment values
normed_values = value_norm(values[morpho.indices[:]])
colors = voltage_colormap(normed_values)
color = colors[0]
compartment_count = len(morpho.x)

if values is None:
compartment_colors = [colors[color_counter % len(colors)]] * compartment_count
else:
color = colors[color_counter % len(colors)]
try:
section_values = values[morpho.indices[:]]
except (IndexError, TypeError):
# Keep scalar behavior: one value means one color everywhere.
section_values = np.repeat(values, compartment_count)
normed_values = value_norm(section_values)
compartment_colors = voltage_colormap(normed_values)
color = compartment_colors[0]

if isinstance(morpho, Soma):
x, y = float(morpho.x[0]/um), float(morpho.y[0]/um)
Expand All @@ -46,25 +52,35 @@ def _plot_morphology2D(morpho, axes, colors,
if show_diameter:
coords_2d = coords[:, :2]
directions = np.diff(coords_2d, axis=0)
orthogonal = np.vstack([-directions[:, 1], directions[:, 0]])
orthogonal = np.vstack([orthogonal.T, orthogonal[:, -1:].T])
radius = np.hstack([morpho.start_diameter[0]/um/2,
morpho.end_diameter/um/2])
orthogonal = np.vstack([-directions[:, 1], directions[:, 0]]).T
orthogonal /= np.sqrt(np.sum(orthogonal**2, axis=1))[:, np.newaxis]

points = np.vstack([coords_2d + orthogonal*radius[:, np.newaxis],
(coords_2d - orthogonal*radius[:, np.newaxis])[::-1]])
patch = Polygon(points, color=color)
axes.add_patch(patch)

start_radius = morpho.start_diameter/um/2
end_radius = morpho.end_diameter/um/2
for idx, color in enumerate(compartment_colors):
start_point = coords_2d[idx]
end_point = coords_2d[idx + 1]
ortho = orthogonal[idx]
points = np.vstack([start_point + ortho*start_radius[idx],
end_point + ortho*end_radius[idx],
end_point - ortho*end_radius[idx],
start_point - ortho*start_radius[idx]])
patch = Polygon(points, color=color)
axes.add_patch(patch)
else:
axes.plot(coords[:, 0], coords[:, 1], color=color, lw=2)
for idx, color in enumerate(compartment_colors):
axes.plot(coords[idx:idx + 2, 0], coords[idx:idx + 2, 1],
color=color, lw=2)
if show_compartments:
# dots at the center of the compartments
if show_diameter:
color = 'black'
axes.plot(morpho.x/um, morpho.y/um, '.', color=color,
mec='none', alpha=0.75)
axes.plot(morpho.x/um, morpho.y/um, '.', color=color,
mec='none', alpha=0.75)
else:
axes.scatter(morpho.x/um, morpho.y/um,
c=compartment_colors,
marker='.', edgecolors='none', alpha=0.75)

for child in morpho.children:
_plot_morphology2D(child, axes=axes,
Expand Down
21 changes: 21 additions & 0 deletions brian2tools/tests/test_plotting.py
Original file line number Diff line number Diff line change
Expand Up @@ -199,6 +199,27 @@ def test_plot_morphology_values():
plot_3d=False)


def test_plot_morphology_values_per_compartment_2d():
set_device('runtime')
morpho = Soma(diameter=20*um)
morpho.axon = Cylinder(diameter=2*um, n=3, length=30*um)
morpho = morpho.generate_coordinates()

# one value for the soma and three different values for the axon compartments
values = np.array([0., 1., 2., 3.])
ax = plot_morphology(morpho, values=values, plot_3d=False,
show_compartments=False, show_diameter=False)

# For the axon (n=3) we expect one plotted line segment per compartment.
section_lines = [line for line in ax.lines if line.get_linewidth() == 2]
assert len(section_lines) == 3

# Compartment values differ, therefore at least two colors should differ.
section_colors = [tuple(line.get_color()) for line in section_lines]
assert len(set(section_colors)) > 1
plt.close()


if __name__ == '__main__':
test_plot_monitors()
test_plot_synapses()
Expand Down