From 41f2562fb96d8e100ef6cbba45e916266ae274e2 Mon Sep 17 00:00:00 2001 From: Sanchit2662 Date: Fri, 27 Mar 2026 23:45:06 +0530 Subject: [PATCH] Fix _float_connection_matrix silently overwriting values for multi-synapse connections Signed-off-by: Sanchit2662 --- brian2tools/plotting/synapses.py | 18 +++++++++++++++--- 1 file changed, 15 insertions(+), 3 deletions(-) diff --git a/brian2tools/plotting/synapses.py b/brian2tools/plotting/synapses.py index dab12f7c..31ea4c6e 100644 --- a/brian2tools/plotting/synapses.py +++ b/brian2tools/plotting/synapses.py @@ -51,6 +51,9 @@ def _float_connection_matrix(sources, targets, values): weights, delays, ...) in the form of a masked matrix (entries without value are set to NaN and masked). + For source-target pairs with multiple synapses the displayed value is the + mean across all synapses between that pair. + Parameters ---------- sources : ndarray of int @@ -65,9 +68,18 @@ def _float_connection_matrix(sources, targets, values): matrix : ma.MaskedArray The connection matrix, masked for NaN values ''' - full_matrix = np.ones((np.max(targets) - np.min(targets) + 1, - np.max(sources) - np.min(sources) + 1)) * np.nan - full_matrix[targets - np.min(targets), sources - np.min(sources)] = values + row = targets - np.min(targets) + col = sources - np.min(sources) + shape = (np.max(targets) - np.min(targets) + 1, + np.max(sources) - np.min(sources) + 1) + # accumulate values and counts separately to compute the mean, avoiding + # the silent last-value-wins overwrite that plain fancy indexing produces + # for multi-synaptic connections + sum_matrix = np.zeros(shape) + count_matrix = np.zeros(shape) + np.add.at(sum_matrix, (row, col), values) + np.add.at(count_matrix, (row, col), 1) + full_matrix = np.where(count_matrix > 0, sum_matrix / count_matrix, np.nan) masked_matrix = ma.masked_invalid(full_matrix, copy=False) return masked_matrix