7 changes: 6 additions & 1 deletion doc/references.rst
@@ -87,7 +87,10 @@ please following citations:

Curation Module
---------------
If you use the :code:`get_potential_auto_merge` method from the curation module, please cite [Llobet]_

If you use the default "similarity_correlograms" preset in the :code:`compute_merge_unit_groups` method from the curation module, please cite [Llobet]_

If you use the "slay" preset in the :code:`compute_merge_unit_groups` method, please cite [Koukuntla]_

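For reference, a minimal usage sketch of the new preset (illustrative only; it assumes a :code:`SortingAnalyzer` whose required extensions, e.g. :code:`templates`, :code:`template_similarity` and :code:`correlograms`, have already been computed)::

    from spikeinterface.curation import compute_merge_unit_groups

    merge_unit_groups = compute_merge_unit_groups(sorting_analyzer, preset="slay")
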
If you use :code:`auto_label_units` or :code:`train_model`, please cite [Jain]_

@@ -139,6 +142,8 @@ References

.. [Jia] `High-density extracellular probes reveal dendritic backpropagation and facilitate neuron classification. 2019 <https://journals.physiology.org/doi/full/10.1152/jn.00680.2018>`_

.. [Koukuntla] `SLAy-ing oversplitting errors in high-density electrophysiology spike sorting. 2025. <https://www.biorxiv.org/content/10.1101/2025.06.20.660590v1>`_

.. [Lee] `YASS: Yet another spike sorter. 2017. <https://www.biorxiv.org/content/10.1101/151928v1>`_

.. [Lemon] Methods for neuronal recording in conscious animals. IBRO Handbook Series. 1984.
262 changes: 262 additions & 0 deletions src/spikeinterface/curation/auto_merge.py
@@ -53,6 +53,10 @@
"knn",
"quality_score",
],
"slay": [
"template_similarity",
"slay_score",
],
}
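# NOTE: the "slay" preset above runs the "template_similarity" step before the
# "slay_score" step, so the pre-computed template differences from that step can be
# reused when the merge decision metric is computed (see `compute_slay_matrix` below).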

_required_extensions = {
@@ -61,6 +65,7 @@
"snr": ["templates", "noise_levels"],
"template_similarity": ["templates", "template_similarity"],
"knn": ["templates", "spike_locations", "spike_amplitudes"],
"slay_score": ["correlograms", "template_similarity"],
}


@@ -85,6 +90,7 @@
"censored_period_ms": 0.3,
},
"quality_score": {"firing_contamination_balance": 1.5, "refractory_period_ms": 1.0, "censored_period_ms": 0.3},
"slay_score": {"k1": 0.25, "k2": 1, "slay_threshold": 0.5},
}
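# Usage sketch (illustrative, not part of the module logic): the "slay_score" defaults
# above can be overridden per call, assuming the `steps_params` argument of
# `compute_merge_unit_groups`, e.g.
#
#     compute_merge_unit_groups(
#         sorting_analyzer,
#         preset="slay",
#         steps_params={"slay_score": {"k1": 0.25, "k2": 1, "slay_threshold": 0.6}},
#     )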


@@ -363,6 +369,14 @@ def compute_merge_unit_groups(
            )
            outs["pairs_decreased_score"] = pairs_decreased_score

        elif step == "slay_score":

            M_ij = compute_slay_matrix(
                sorting_analyzer, params["k1"], params["k2"], templates_diff=outs["templates_diff"], pair_mask=pair_mask
            )

            pair_mask = pair_mask & (M_ij > params["slay_threshold"])

    # FINAL STEP : create the final list from pair_mask boolean matrix
    ind1, ind2 = np.nonzero(pair_mask)
    merge_unit_groups = list(zip(unit_ids[ind1], unit_ids[ind2]))
@@ -1525,3 +1539,251 @@ def estimate_cross_contamination(
    )

    return estimation, p_value


def compute_slay_matrix(
    sorting_analyzer: SortingAnalyzer,
    k1: float,
    k2: float,
    templates_diff: np.ndarray | None,
    pair_mask: np.ndarray | None = None,
):
    """
    Computes the "merge decision metric" from the SLAy method, made by combining
    a template similarity measure, a cross-correlation significance measure and a
    sliding refractory period violation measure. A large M suggests that two
    units should be merged.

    Parameters
    ----------
    sorting_analyzer : SortingAnalyzer
        The sorting analyzer object containing the spike sorting data
    k1 : float
        Coefficient determining the importance of the cross-correlation significance
    k2 : float
        Coefficient determining the importance of the sliding rp violation
    templates_diff : np.ndarray | None
        Pre-computed template similarity difference matrix. If None, it will be retrieved from the sorting_analyzer.
    pair_mask : None | np.ndarray, default: None
        A bool matrix describing which pairs are possible merges based on previous steps

    References
    ----------
    Based on the computation originally implemented in SLAy [Koukuntla]_.

    This implementation is based on one of the original implementations written by Sai Koukuntla,
    found at https://github.com/saikoukunt/SLAy.
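
    Notes
    -----
    The metric is ``M_ij = sigma_ij + k1 * rho_ij - k2 * eta_ij``. As a purely
    illustrative example with made-up values: a pair with template similarity 0.8,
    cross-correlation significance 0.6 and refractory period violation confidence 0.1
    gives, with the default ``k1 = 0.25`` and ``k2 = 1``,
    ``M_ij = 0.8 + 0.25 * 0.6 - 0.1 = 0.85``, which is above the default
    ``slay_threshold`` of 0.5 used by the "slay" preset.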
"""

num_units = sorting_analyzer.get_num_units()

if pair_mask is None:
pair_mask = np.triu(np.arange(num_units), 1) > 0

if templates_diff is not None:
sigma_ij = 1 - templates_diff
else:
sigma_ij = sorting_analyzer.get_extension("template_similarity").get_data()
rho_ij, eta_ij = compute_xcorr_and_rp(sorting_analyzer, pair_mask)

M_ij = sigma_ij + k1 * rho_ij - k2 * eta_ij

return M_ij


def compute_xcorr_and_rp(sorting_analyzer: SortingAnalyzer, pair_mask: np.ndarray):
    """
    Computes a cross-correlation significance measure and a sliding refractory period violation
    measure for all units in the `sorting_analyzer`.

    Parameters
    ----------
    sorting_analyzer : SortingAnalyzer
        The sorting analyzer object containing the spike sorting data
    pair_mask : np.ndarray
        A bool matrix describing which pairs are possible merges based on previous steps
    """

    correlograms_extension = sorting_analyzer.get_extension("correlograms")
    ccgs, _ = correlograms_extension.get_data()

    # bin size in ms; converted to seconds where the SLAy helper functions expect it
    bin_size_ms = correlograms_extension.params["bin_ms"]

    rho_ij = np.zeros([len(sorting_analyzer.unit_ids), len(sorting_analyzer.unit_ids)])
    eta_ij = np.zeros([len(sorting_analyzer.unit_ids), len(sorting_analyzer.unit_ids)])

    for unit_index_1, _ in enumerate(sorting_analyzer.unit_ids):
        for unit_index_2, _ in enumerate(sorting_analyzer.unit_ids):

            # Don't waste time computing the other metrics if the units are not candidate merges
            if not pair_mask[unit_index_1, unit_index_2]:
                continue

            xgram = ccgs[unit_index_1, unit_index_2, :]

            rho_ij[unit_index_1, unit_index_2] = _compute_xcorr_pair(
                xgram, bin_size_s=bin_size_ms / 1000, min_xcorr_rate=0
            )
            eta_ij[unit_index_1, unit_index_2] = _sliding_RP_viol_pair(xgram, bin_size_ms=bin_size_ms)

    return rho_ij, eta_ij


def _compute_xcorr_pair(
    xgram,
    bin_size_s: float,
    min_xcorr_rate: float,
) -> float:
    """
    Calculates a cross-correlation significance metric for a cluster pair.

    Uses the wasserstein distance between an observed cross-correlogram and a null
    distribution as an estimate of how significant the dependence between
    two neurons is. Low spike count cross-correlograms have large wasserstein
    distances from null by chance, so we first try to expand the window size. If
    that fails to yield enough spikes, we apply a penalty to the metric.

    Ported from https://github.com/saikoukunt/SLAy.

    Parameters
    ----------
    xgram : np.array
        The raw cross-correlogram for the cluster pair.
    bin_size_s : float
        The width in seconds of the bins of the input ccgs.
    min_xcorr_rate : float
        The minimum ccg firing rate in Hz.

    Returns
    -------
    sig : float
        The calculated cross-correlation significance metric.
    """

    from scipy.signal import butter, find_peaks_cwt, sosfiltfilt
    from scipy.stats import wasserstein_distance

    # calculate low-pass filtered second derivative of ccg
    fs = 1 / bin_size_s
    cutoff_freq = 100
    nyquist = fs / 2
    cutoff = cutoff_freq / nyquist
    peak_width = 0.002 / bin_size_s

    xgram_2d = np.diff(xgram, 2)
    sos = butter(4, cutoff, output="sos")
    xgram_2d = sosfiltfilt(sos, xgram_2d)

    if xgram.sum() == 0:
        return 0

    # find negative peaks of second derivative of ccg, these are the edges of dips in ccg
    peaks = find_peaks_cwt(-xgram_2d, peak_width, noise_perc=90) + 1
    # if no peaks are found, return a very low significance
    if peaks.shape[0] == 0:
        return -4
    peaks = np.abs(peaks - xgram.shape[0] / 2)
    peaks = peaks[peaks > 0.5 * peak_width]
    min_peaks = np.sort(peaks)

    # start with peaks closest to 0 and move to the next set of peaks if the event count is too low
    window_width = min_peaks * 1.5
    starts = np.maximum(xgram.shape[0] / 2 - window_width, 0)
    ends = np.minimum(xgram.shape[0] / 2 + window_width, xgram.shape[0] - 1)
    ind = 0
    xgram_window = xgram[int(starts[0]) : int(ends[0] + 1)]
    xgram_sum = xgram_window.sum()
    window_size = xgram_window.shape[0] * bin_size_s
    while (xgram_sum < (min_xcorr_rate * window_size * 10)) and (ind < starts.shape[0]):
        xgram_window = xgram[int(starts[ind]) : int(ends[ind] + 1)]
        xgram_sum = xgram_window.sum()
        window_size = xgram_window.shape[0] * bin_size_s
        ind += 1
    # use the whole ccg if peak finding fails
    if ind == starts.shape[0]:
        xgram_window = xgram

    # TODO: was getting error messages when xgram_window was all zero. Why was this happening?
    if np.abs(xgram_window).sum() == 0:
        return 0

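    # the significance is a scaled 1-D wasserstein distance between the windowed ccg,
    # treated as a distribution over lag, and a flat (uniform) null distribution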
    sig = (
        wasserstein_distance(
            np.arange(xgram_window.shape[0]) / xgram_window.shape[0],
            np.arange(xgram_window.shape[0]) / xgram_window.shape[0],
            xgram_window,
            np.ones_like(xgram_window),
        )
        * 4
    )

    if xgram_window.sum() < (min_xcorr_rate * window_size):
        sig *= (xgram_window.sum() / (min_xcorr_rate * window_size)) ** 2

    # if sig < 0.04 and xgram_window.sum() < (min_xcorr_rate * window_size):
    if xgram_window.sum() < (min_xcorr_rate / 4 * window_size):
        sig = -4  # don't merge if the event count is way too low

    return sig


def _sliding_RP_viol_pair(
    correlogram,
    bin_size_ms: float,
    accept_threshold: float = 0.15,
) -> float:
    """
    Calculate the sliding refractory period violation confidence for a cluster.

    Ported from https://github.com/saikoukunt/SLAy.

    Parameters
    ----------
    correlogram : np.array
        The auto-correlogram of the cluster.
    bin_size_ms : float
        The width in ms of the bins of the input ccgs.
    accept_threshold : float, default: 0.15
        The maximum acceptable contamination, expressed as a fraction of the baseline firing rate.

    Returns
    -------
    rp_viol : float
        The refractory period violation confidence for the cluster.
    """
    from scipy.signal import butter, sosfiltfilt
    from scipy.stats import poisson

    # create various refractory period sizes to test (between 0 and 20x bin size)
    all_refractory_periods = np.arange(0, 21 * bin_size_ms, bin_size_ms) / 1000
    test_refractory_period_indices = np.array([1, 2, 4, 6, 8, 12, 16, 20], dtype="int8")
    test_refractory_periods = [
        all_refractory_periods[test_rp_index] for test_rp_index in test_refractory_period_indices
    ]

    # calculate and average the two halves of the acg to ensure symmetry;
    # keep only the second half of the acg, refractory period violations are compared from the center of the acg
    half_len = int(correlogram.shape[0] / 2)
    correlogram = (correlogram[half_len:] + correlogram[:half_len][::-1]) / 2

    acg_cumsum = np.cumsum(correlogram)
    sum_res = acg_cumsum[test_refractory_period_indices - 1]  # -1 bc 0th bin corresponds to 0-bin_size ms

    # low-pass filter the acg and use its max as the baseline event rate
    order = 4  # filter order
    cutoff_freq = 250  # Hz
    fs = 1 / bin_size_ms * 1000
    nyquist = fs / 2
    cutoff = cutoff_freq / nyquist
    sos = butter(order, cutoff, btype="low", output="sos")
    smoothed_acg = sosfiltfilt(sos, correlogram)

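    # max_conts_max below is the expected number of contaminating spikes in each tested
    # refractory period if contamination were exactly accept_threshold times the
    # baseline rate (taken as the max of the smoothed acg)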
    bin_rate_max = np.max(smoothed_acg)
    max_conts_max = np.array(test_refractory_periods) / bin_size_ms * 1000 * (bin_rate_max * accept_threshold)
    # compute confidence of less than acceptThresh contamination at each refractory period
    confs = 1 - poisson.cdf(sum_res, max_conts_max)
    rp_viol = 1 - confs.max()

    return rp_viol
4 changes: 2 additions & 2 deletions src/spikeinterface/curation/tests/test_auto_merge.py
@@ -15,7 +15,7 @@


@pytest.mark.parametrize(
"preset", ["x_contaminations", "feature_neighbors", "temporal_splits", "similarity_correlograms", None]
"preset", ["x_contaminations", "feature_neighbors", "temporal_splits", "similarity_correlograms", "slay", None]
)
def test_compute_merge_unit_groups(sorting_analyzer_with_splits, preset):

@@ -59,7 +59,7 @@ def test_compute_merge_unit_groups(sorting_analyzer_with_splits, preset):


@pytest.mark.parametrize(
"preset", ["x_contaminations", "feature_neighbors", "temporal_splits", "similarity_correlograms"]
"preset", ["x_contaminations", "feature_neighbors", "temporal_splits", "similarity_correlograms", "slay"]
)
def test_compute_merge_unit_groups_multi_segment(sorting_analyzer_multi_segment_for_curation, preset):
    job_kwargs = dict(n_jobs=-1)