Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 15 additions & 3 deletions brian2tools/plotting/synapses.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,9 @@ def _float_connection_matrix(sources, targets, values):
weights, delays, ...) in the form of a masked matrix (entries without value
are set to NaN and masked).

For source-target pairs with multiple synapses the displayed value is the
mean across all synapses between that pair.

Parameters
----------
sources : ndarray of int
Expand All @@ -65,9 +68,18 @@ def _float_connection_matrix(sources, targets, values):
matrix : ma.MaskedArray
The connection matrix, masked for NaN values
'''
full_matrix = np.ones((np.max(targets) - np.min(targets) + 1,
np.max(sources) - np.min(sources) + 1)) * np.nan
full_matrix[targets - np.min(targets), sources - np.min(sources)] = values
row = targets - np.min(targets)
col = sources - np.min(sources)
shape = (np.max(targets) - np.min(targets) + 1,
np.max(sources) - np.min(sources) + 1)
# accumulate values and counts separately to compute the mean, avoiding
# the silent last-value-wins overwrite that plain fancy indexing produces
# for multi-synaptic connections
sum_matrix = np.zeros(shape)
count_matrix = np.zeros(shape)
np.add.at(sum_matrix, (row, col), values)
np.add.at(count_matrix, (row, col), 1)
full_matrix = np.where(count_matrix > 0, sum_matrix / count_matrix, np.nan)
masked_matrix = ma.masked_invalid(full_matrix, copy=False)
return masked_matrix

Expand Down