From 53cb1d598d3002a076c9b2bdf1ba66c4c1a32121 Mon Sep 17 00:00:00 2001 From: Pierre Yger Date: Thu, 4 Dec 2025 16:18:22 +0100 Subject: [PATCH 01/38] Final cleaning --- src/spikeinterface/sorters/internal/spyking_circus2.py | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/src/spikeinterface/sorters/internal/spyking_circus2.py b/src/spikeinterface/sorters/internal/spyking_circus2.py index d4f9f39cbd..fbe20fac34 100644 --- a/src/spikeinterface/sorters/internal/spyking_circus2.py +++ b/src/spikeinterface/sorters/internal/spyking_circus2.py @@ -480,6 +480,7 @@ def final_cleaning_circus( template_diff_thresh=np.arange(0.05, 0.5, 0.05), debug_folder=None, noise_levels=None, + sd_ratio_threshold=2.0, job_kwargs=dict(), ): @@ -509,4 +510,13 @@ def final_cleaning_circus( **job_kwargs, ) + if sd_ratio_threshold is not None: + from spikeinterface.qualitymetrics.misc_metrics import compute_sd_ratio + final_sa.compute('spike_amplitudes', **job_kwargs) + sd_ratios = compute_sd_ratio(final_sa) + to_keep = [] + for id, value in sd_ratios.items(): + if value < sd_ratio_threshold: + to_keep += [id] + final_sa = final_sa.select_units(to_keep) return final_sa From 17810601c39be684949b19647f4e5b1b17a31721 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 4 Dec 2025 15:20:19 +0000 Subject: [PATCH 02/38] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/spikeinterface/sorters/internal/spyking_circus2.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/spikeinterface/sorters/internal/spyking_circus2.py b/src/spikeinterface/sorters/internal/spyking_circus2.py index fbe20fac34..0fb8555648 100644 --- a/src/spikeinterface/sorters/internal/spyking_circus2.py +++ b/src/spikeinterface/sorters/internal/spyking_circus2.py @@ -512,7 +512,8 @@ def final_cleaning_circus( if sd_ratio_threshold is not None: from 
spikeinterface.qualitymetrics.misc_metrics import compute_sd_ratio - final_sa.compute('spike_amplitudes', **job_kwargs) + + final_sa.compute("spike_amplitudes", **job_kwargs) sd_ratios = compute_sd_ratio(final_sa) to_keep = [] for id, value in sd_ratios.items(): From 45341a30537dc125d4984dc1d4ef39aa511a9282 Mon Sep 17 00:00:00 2001 From: Pierre Yger Date: Thu, 4 Dec 2025 21:48:50 +0100 Subject: [PATCH 03/38] WIP --- .../sorters/internal/spyking_circus2.py | 4 +++- .../sortingcomponents/clustering/tools.py | 19 ++++++++++++++- src/spikeinterface/sortingcomponents/tools.py | 23 ++++++++++++++++--- 3 files changed, 41 insertions(+), 5 deletions(-) diff --git a/src/spikeinterface/sorters/internal/spyking_circus2.py b/src/spikeinterface/sorters/internal/spyking_circus2.py index 0fb8555648..6e4b2a0572 100644 --- a/src/spikeinterface/sorters/internal/spyking_circus2.py +++ b/src/spikeinterface/sorters/internal/spyking_circus2.py @@ -365,7 +365,7 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): else: from spikeinterface.sortingcomponents.clustering.tools import get_templates_from_peaks_and_svd - dense_templates, new_sparse_mask = get_templates_from_peaks_and_svd( + dense_templates, new_sparse_mask, sd_ratios = get_templates_from_peaks_and_svd( recording_w, selected_peaks, peak_labels, @@ -375,6 +375,7 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): more_outs["peaks_svd"], more_outs["peak_svd_sparse_mask"], operator="median", + sd_ratios=True, ) # this release the peak_svd memmap file templates = dense_templates.to_sparse(new_sparse_mask) @@ -383,6 +384,7 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): cleaning_kwargs = params.get("cleaning", {}).copy() cleaning_kwargs["noise_levels"] = noise_levels + cleaning_kwargs["sd_ratios"] = sd_ratios cleaning_kwargs["remove_empty"] = True templates = clean_templates(templates, **cleaning_kwargs) diff --git a/src/spikeinterface/sortingcomponents/clustering/tools.py 
b/src/spikeinterface/sortingcomponents/clustering/tools.py index fd932ef26c..3572a737f9 100644 --- a/src/spikeinterface/sortingcomponents/clustering/tools.py +++ b/src/spikeinterface/sortingcomponents/clustering/tools.py @@ -236,6 +236,7 @@ def get_templates_from_peaks_and_svd( svd_features, sparsity_mask, operator="average", + sd_ratios=False, ): """ Get templates from recording using the SVD components @@ -267,6 +268,8 @@ def get_templates_from_peaks_and_svd( The estimated templates object as a dense template (but internanally contain sparse channels). final_sparsity_mask: np.array The final sparsity mask. Note that the template object is dense but with zeros. + sd_ratios: np.array + The standard deviation ratio of the templates at time 0 """ assert operator in ["average", "median"], "operator should be either 'average' or 'median'" @@ -282,6 +285,9 @@ def get_templates_from_peaks_and_svd( num_channels = recording.get_num_channels() templates_array = np.zeros((len(labels), nbefore + nafter, num_channels), dtype=np.float32) + if sd_ratios: + ratios = np.zeros(len(labels), dtype=np.float32) + final_sparsity_mask = np.zeros((len(labels), num_channels), dtype="bool") for unit_ind, label in enumerate(labels): mask = valid_labels == label @@ -298,6 +304,14 @@ def get_templates_from_peaks_and_svd( data = np.median(local_svd[sub_mask, :, count], 0) templates_array[unit_ind, :, i] = svd_model.inverse_transform(data.reshape(1, -1)) + if sd_ratios: + count = np.where(best_channel == np.flatnonzero(sparsity_mask[best_channel]))[0][0] + data = svd_model.inverse_transform(local_svd[sub_mask, :, count]) + if len(data) == 1: + ratios[unit_ind] = 0.0 + else: + ratios[unit_ind] = np.std(data[:, nbefore]) + dense_templates = Templates( templates_array=templates_array, sampling_frequency=fs, @@ -309,4 +323,7 @@ def get_templates_from_peaks_and_svd( is_in_uV=False, ) - return dense_templates, final_sparsity_mask + if sd_ratios: + return dense_templates, final_sparsity_mask, ratios + 
else: + return dense_templates, final_sparsity_mask \ No newline at end of file diff --git a/src/spikeinterface/sortingcomponents/tools.py b/src/spikeinterface/sortingcomponents/tools.py index f1aca75363..d84097907f 100644 --- a/src/spikeinterface/sortingcomponents/tools.py +++ b/src/spikeinterface/sortingcomponents/tools.py @@ -17,9 +17,9 @@ from spikeinterface.core.sparsity import ChannelSparsity from spikeinterface.core.sparsity import compute_sparsity from spikeinterface.core.analyzer_extension_core import ComputeTemplates, ComputeNoiseLevels -from spikeinterface.core.template_tools import get_template_extremum_channel_peak_shift +from spikeinterface.core.template_tools import get_template_extremum_channel_peak_shift, get_template_extremum_channel from spikeinterface.core.recording_tools import get_noise_levels -from spikeinterface.core.sorting_tools import spike_vector_to_indices, get_numba_vector_to_list_of_spiketrain +from spikeinterface.core.sorting_tools import get_numba_vector_to_list_of_spiketrain def make_multi_method_doc(methods, ident=" "): @@ -533,11 +533,28 @@ def get_shuffled_recording_slices(recording, job_kwargs=None, seed=None): def clean_templates( - templates, sparsify_threshold=0.25, noise_levels=None, min_snr=None, max_jitter_ms=None, remove_empty=True + templates, + sparsify_threshold=0.25, + noise_levels=None, + min_snr=None, + max_jitter_ms=None, + remove_empty=True, + sd_ratio_threshold=2.0, + sd_ratios=None, ): """ Clean a Templates object by removing empty units and applying sparsity if provided. 
""" + ## First if sd_ratios are provided, we remove the templates that have a high sd_ratio + if sd_ratios is not None: + assert noise_levels is not None, "noise_levels must be provided if sd_ratios is given" + to_select = [] + best_channels = get_template_extremum_channel(templates, outputs="index") + for count, unit_id in enumerate(templates.unit_ids): + ratio = sd_ratios[count]/noise_levels[best_channels[unit_id]] + if ratio <= sd_ratio_threshold: + to_select += [unit_id] + templates = templates.select_units(to_select) ## First we sparsify the templates (using peak-to-peak amplitude avoid sign issues) if sparsify_threshold is not None: From 5bacdae40a67ff3b629f63c45bc2c90b90923ad3 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 4 Dec 2025 20:49:28 +0000 Subject: [PATCH 04/38] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- .../sortingcomponents/clustering/tools.py | 4 ++-- src/spikeinterface/sortingcomponents/tools.py | 12 ++++++------ 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/src/spikeinterface/sortingcomponents/clustering/tools.py b/src/spikeinterface/sortingcomponents/clustering/tools.py index 3572a737f9..73e3d9b5c6 100644 --- a/src/spikeinterface/sortingcomponents/clustering/tools.py +++ b/src/spikeinterface/sortingcomponents/clustering/tools.py @@ -287,7 +287,7 @@ def get_templates_from_peaks_and_svd( templates_array = np.zeros((len(labels), nbefore + nafter, num_channels), dtype=np.float32) if sd_ratios: ratios = np.zeros(len(labels), dtype=np.float32) - + final_sparsity_mask = np.zeros((len(labels), num_channels), dtype="bool") for unit_ind, label in enumerate(labels): mask = valid_labels == label @@ -326,4 +326,4 @@ def get_templates_from_peaks_and_svd( if sd_ratios: return dense_templates, final_sparsity_mask, ratios else: - return dense_templates, final_sparsity_mask \ No newline at end of file + return 
dense_templates, final_sparsity_mask diff --git a/src/spikeinterface/sortingcomponents/tools.py b/src/spikeinterface/sortingcomponents/tools.py index d84097907f..9c8efa9c41 100644 --- a/src/spikeinterface/sortingcomponents/tools.py +++ b/src/spikeinterface/sortingcomponents/tools.py @@ -533,11 +533,11 @@ def get_shuffled_recording_slices(recording, job_kwargs=None, seed=None): def clean_templates( - templates, - sparsify_threshold=0.25, - noise_levels=None, - min_snr=None, - max_jitter_ms=None, + templates, + sparsify_threshold=0.25, + noise_levels=None, + min_snr=None, + max_jitter_ms=None, remove_empty=True, sd_ratio_threshold=2.0, sd_ratios=None, @@ -551,7 +551,7 @@ def clean_templates( to_select = [] best_channels = get_template_extremum_channel(templates, outputs="index") for count, unit_id in enumerate(templates.unit_ids): - ratio = sd_ratios[count]/noise_levels[best_channels[unit_id]] + ratio = sd_ratios[count] / noise_levels[best_channels[unit_id]] if ratio <= sd_ratio_threshold: to_select += [unit_id] templates = templates.select_units(to_select) From 3c1157e6f2632b6b44edd55b51738ac518e4716a Mon Sep 17 00:00:00 2001 From: Pierre Yger Date: Fri, 5 Dec 2025 09:21:18 +0100 Subject: [PATCH 05/38] WIP --- .../sorters/internal/spyking_circus2.py | 13 ++++++++--- .../sortingcomponents/clustering/tools.py | 15 ++++++------- src/spikeinterface/sortingcomponents/tools.py | 22 ++++++++++++------- 3 files changed, 31 insertions(+), 19 deletions(-) diff --git a/src/spikeinterface/sorters/internal/spyking_circus2.py b/src/spikeinterface/sorters/internal/spyking_circus2.py index 6e4b2a0572..9035d5eaf4 100644 --- a/src/spikeinterface/sorters/internal/spyking_circus2.py +++ b/src/spikeinterface/sorters/internal/spyking_circus2.py @@ -328,6 +328,7 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): method=clustering_method, method_kwargs=clustering_params, extra_outputs=True, + verbose=verbose, job_kwargs=job_kwargs, ) @@ -381,10 +382,13 @@ def 
_run_from_folder(cls, sorter_output_folder, params, verbose): templates = dense_templates.to_sparse(new_sparse_mask) del more_outs + + if verbose: + print("We have %d clusters" % len(templates.unit_ids)) cleaning_kwargs = params.get("cleaning", {}).copy() cleaning_kwargs["noise_levels"] = noise_levels - cleaning_kwargs["sd_ratios"] = sd_ratios + #cleaning_kwargs["sd_ratios"] = sd_ratios cleaning_kwargs["remove_empty"] = True templates = clean_templates(templates, **cleaning_kwargs) @@ -482,7 +486,7 @@ def final_cleaning_circus( template_diff_thresh=np.arange(0.05, 0.5, 0.05), debug_folder=None, noise_levels=None, - sd_ratio_threshold=2.0, + sd_ratio_threshold=3.0, job_kwargs=dict(), ): @@ -517,9 +521,12 @@ def final_cleaning_circus( final_sa.compute("spike_amplitudes", **job_kwargs) sd_ratios = compute_sd_ratio(final_sa) + ratios = np.array(list(sd_ratios.values())) + center = np.median(ratios) + mad = np.median(np.abs(ratios - center)) to_keep = [] for id, value in sd_ratios.items(): - if value < sd_ratio_threshold: + if value <= center + mad*sd_ratio_threshold: to_keep += [id] final_sa = final_sa.select_units(to_keep) return final_sa diff --git a/src/spikeinterface/sortingcomponents/clustering/tools.py b/src/spikeinterface/sortingcomponents/clustering/tools.py index 73e3d9b5c6..b9e3e27968 100644 --- a/src/spikeinterface/sortingcomponents/clustering/tools.py +++ b/src/spikeinterface/sortingcomponents/clustering/tools.py @@ -303,14 +303,13 @@ def get_templates_from_peaks_and_svd( elif operator == "median": data = np.median(local_svd[sub_mask, :, count], 0) templates_array[unit_ind, :, i] = svd_model.inverse_transform(data.reshape(1, -1)) - - if sd_ratios: - count = np.where(best_channel == np.flatnonzero(sparsity_mask[best_channel]))[0][0] - data = svd_model.inverse_transform(local_svd[sub_mask, :, count]) - if len(data) == 1: - ratios[unit_ind] = 0.0 - else: - ratios[unit_ind] = np.std(data[:, nbefore]) + + if i == best_channel and sd_ratios: + data = 
svd_model.inverse_transform(local_svd[sub_mask, :, count]) + if len(data) == 1: + ratios[unit_ind] = 0.0 + else: + ratios[unit_ind] = np.std(data[:, nbefore]) dense_templates = Templates( templates_array=templates_array, diff --git a/src/spikeinterface/sortingcomponents/tools.py b/src/spikeinterface/sortingcomponents/tools.py index 9c8efa9c41..f72648e93a 100644 --- a/src/spikeinterface/sortingcomponents/tools.py +++ b/src/spikeinterface/sortingcomponents/tools.py @@ -533,13 +533,13 @@ def get_shuffled_recording_slices(recording, job_kwargs=None, seed=None): def clean_templates( - templates, - sparsify_threshold=0.25, - noise_levels=None, - min_snr=None, - max_jitter_ms=None, + templates, + sparsify_threshold=0.25, + noise_levels=None, + min_snr=None, + max_jitter_ms=None, remove_empty=True, - sd_ratio_threshold=2.0, + sd_ratio_threshold=3.0, sd_ratios=None, ): """ @@ -550,9 +550,15 @@ def clean_templates( assert noise_levels is not None, "noise_levels must be provided if sd_ratios is given" to_select = [] best_channels = get_template_extremum_channel(templates, outputs="index") + all_ratios = [] for count, unit_id in enumerate(templates.unit_ids): - ratio = sd_ratios[count] / noise_levels[best_channels[unit_id]] - if ratio <= sd_ratio_threshold: + ratio = sd_ratios[count]/noise_levels[best_channels[unit_id]] + all_ratios.append(ratio) + centered_ratios = np.array(all_ratios) - 1.0 + mad = np.median(np.abs(centered_ratios - np.median(centered_ratios))) + + for count, unit_id in enumerate(templates.unit_ids): + if all_ratios[count] <= 1 + sd_ratio_threshold * mad: to_select += [unit_id] templates = templates.select_units(to_select) From bb0df41fcc14369a6043e04feb929a36a9e28135 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 5 Dec 2025 08:21:48 +0000 Subject: [PATCH 06/38] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- 
.../sorters/internal/spyking_circus2.py | 6 +++--- .../sortingcomponents/clustering/tools.py | 2 +- src/spikeinterface/sortingcomponents/tools.py | 12 ++++++------ 3 files changed, 10 insertions(+), 10 deletions(-) diff --git a/src/spikeinterface/sorters/internal/spyking_circus2.py b/src/spikeinterface/sorters/internal/spyking_circus2.py index 9035d5eaf4..331d7fc068 100644 --- a/src/spikeinterface/sorters/internal/spyking_circus2.py +++ b/src/spikeinterface/sorters/internal/spyking_circus2.py @@ -382,13 +382,13 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): templates = dense_templates.to_sparse(new_sparse_mask) del more_outs - + if verbose: print("We have %d clusters" % len(templates.unit_ids)) cleaning_kwargs = params.get("cleaning", {}).copy() cleaning_kwargs["noise_levels"] = noise_levels - #cleaning_kwargs["sd_ratios"] = sd_ratios + # cleaning_kwargs["sd_ratios"] = sd_ratios cleaning_kwargs["remove_empty"] = True templates = clean_templates(templates, **cleaning_kwargs) @@ -526,7 +526,7 @@ def final_cleaning_circus( mad = np.median(np.abs(ratios - center)) to_keep = [] for id, value in sd_ratios.items(): - if value <= center + mad*sd_ratio_threshold: + if value <= center + mad * sd_ratio_threshold: to_keep += [id] final_sa = final_sa.select_units(to_keep) return final_sa diff --git a/src/spikeinterface/sortingcomponents/clustering/tools.py b/src/spikeinterface/sortingcomponents/clustering/tools.py index b9e3e27968..56c9f3fa45 100644 --- a/src/spikeinterface/sortingcomponents/clustering/tools.py +++ b/src/spikeinterface/sortingcomponents/clustering/tools.py @@ -303,7 +303,7 @@ def get_templates_from_peaks_and_svd( elif operator == "median": data = np.median(local_svd[sub_mask, :, count], 0) templates_array[unit_ind, :, i] = svd_model.inverse_transform(data.reshape(1, -1)) - + if i == best_channel and sd_ratios: data = svd_model.inverse_transform(local_svd[sub_mask, :, count]) if len(data) == 1: diff --git 
a/src/spikeinterface/sortingcomponents/tools.py b/src/spikeinterface/sortingcomponents/tools.py index f72648e93a..2ade07e355 100644 --- a/src/spikeinterface/sortingcomponents/tools.py +++ b/src/spikeinterface/sortingcomponents/tools.py @@ -533,11 +533,11 @@ def get_shuffled_recording_slices(recording, job_kwargs=None, seed=None): def clean_templates( - templates, - sparsify_threshold=0.25, - noise_levels=None, - min_snr=None, - max_jitter_ms=None, + templates, + sparsify_threshold=0.25, + noise_levels=None, + min_snr=None, + max_jitter_ms=None, remove_empty=True, sd_ratio_threshold=3.0, sd_ratios=None, @@ -552,7 +552,7 @@ def clean_templates( best_channels = get_template_extremum_channel(templates, outputs="index") all_ratios = [] for count, unit_id in enumerate(templates.unit_ids): - ratio = sd_ratios[count]/noise_levels[best_channels[unit_id]] + ratio = sd_ratios[count] / noise_levels[best_channels[unit_id]] all_ratios.append(ratio) centered_ratios = np.array(all_ratios) - 1.0 mad = np.median(np.abs(centered_ratios - np.median(centered_ratios))) From 1b90a465e62d3032728f082b6cf816be00dc039e Mon Sep 17 00:00:00 2001 From: Pierre Yger Date: Fri, 5 Dec 2025 10:58:33 +0100 Subject: [PATCH 07/38] Cosmetic --- src/spikeinterface/sorters/internal/spyking_circus2.py | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/src/spikeinterface/sorters/internal/spyking_circus2.py b/src/spikeinterface/sorters/internal/spyking_circus2.py index 9035d5eaf4..d3b8c1b7da 100644 --- a/src/spikeinterface/sorters/internal/spyking_circus2.py +++ b/src/spikeinterface/sorters/internal/spyking_circus2.py @@ -39,7 +39,7 @@ class Spykingcircus2Sorter(ComponentsBasedSorter): "motion_correction": {"preset": "dredge_fast"}, "merging": {"max_distance_um": 50}, "clustering": {"method": "iterative-hdbscan", "method_kwargs": dict()}, - "cleaning": {"min_snr": 5, "max_jitter_ms": 0.1, "sparsify_threshold": None}, + "cleaning": {"min_snr": 5, "max_jitter_ms": 0.1}, "matching": 
{"method": "circus-omp", "method_kwargs": dict(), "pipeline_kwargs": dict()}, "apply_preprocessing": True, "apply_whitening": True, @@ -118,8 +118,6 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): ms_before = params["general"].get("ms_before", 0.5) ms_after = params["general"].get("ms_after", 1.5) radius_um = params["general"].get("radius_um", 100.0) - detect_threshold = params["detection"]["method_kwargs"].get("detect_threshold", 5) - peak_sign = params["detection"].get("peak_sign", "neg") deterministic = params["deterministic_peaks_detection"] debug = params["debug"] seed = params["seed"] @@ -383,8 +381,8 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): del more_outs - if verbose: - print("We have %d clusters" % len(templates.unit_ids)) + #if verbose: + # print("We have %d clusters" % len(templates.unit_ids)) cleaning_kwargs = params.get("cleaning", {}).copy() cleaning_kwargs["noise_levels"] = noise_levels From 268c7a325269e6dc9eff228abb7cdccdea8076a7 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 5 Dec 2025 10:00:10 +0000 Subject: [PATCH 08/38] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/spikeinterface/sorters/internal/spyking_circus2.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/spikeinterface/sorters/internal/spyking_circus2.py b/src/spikeinterface/sorters/internal/spyking_circus2.py index dac5daf9b8..e2288e5808 100644 --- a/src/spikeinterface/sorters/internal/spyking_circus2.py +++ b/src/spikeinterface/sorters/internal/spyking_circus2.py @@ -380,8 +380,8 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): templates = dense_templates.to_sparse(new_sparse_mask) del more_outs - - #if verbose: + + # if verbose: # print("We have %d clusters" % len(templates.unit_ids)) cleaning_kwargs = params.get("cleaning", {}).copy() From 
97300507fe1f54abdf156c62b2e0d98713332024 Mon Sep 17 00:00:00 2001 From: Pierre Yger Date: Fri, 5 Dec 2025 11:23:54 +0100 Subject: [PATCH 09/38] WIP --- src/spikeinterface/sorters/internal/spyking_circus2.py | 4 ++-- src/spikeinterface/sorters/internal/tridesclous2.py | 1 - 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/src/spikeinterface/sorters/internal/spyking_circus2.py b/src/spikeinterface/sorters/internal/spyking_circus2.py index dac5daf9b8..f4d08c0a27 100644 --- a/src/spikeinterface/sorters/internal/spyking_circus2.py +++ b/src/spikeinterface/sorters/internal/spyking_circus2.py @@ -39,7 +39,7 @@ class Spykingcircus2Sorter(ComponentsBasedSorter): "motion_correction": {"preset": "dredge_fast"}, "merging": {"max_distance_um": 50}, "clustering": {"method": "iterative-hdbscan", "method_kwargs": dict()}, - "cleaning": {"min_snr": 5, "max_jitter_ms": 0.1}, + "cleaning": {"min_snr": 5, "max_jitter_ms": 0.1, "sparsity_threshold": 1}, "matching": {"method": "circus-omp", "method_kwargs": dict(), "pipeline_kwargs": dict()}, "apply_preprocessing": True, "apply_whitening": True, @@ -484,7 +484,7 @@ def final_cleaning_circus( template_diff_thresh=np.arange(0.05, 0.5, 0.05), debug_folder=None, noise_levels=None, - sd_ratio_threshold=3.0, + sd_ratio_threshold=None, job_kwargs=dict(), ): diff --git a/src/spikeinterface/sorters/internal/tridesclous2.py b/src/spikeinterface/sorters/internal/tridesclous2.py index 2f7e84dcf4..c101c67ff5 100644 --- a/src/spikeinterface/sorters/internal/tridesclous2.py +++ b/src/spikeinterface/sorters/internal/tridesclous2.py @@ -93,7 +93,6 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): from spikeinterface.sortingcomponents.peak_detection import detect_peaks from spikeinterface.sortingcomponents.peak_selection import select_peaks from spikeinterface.sortingcomponents.clustering.main import find_clusters_from_peaks, clustering_methods - from spikeinterface.sortingcomponents.tools import remove_empty_templates 
from spikeinterface.preprocessing import correct_motion from spikeinterface.sortingcomponents.motion import InterpolateMotionRecording from spikeinterface.sortingcomponents.tools import clean_templates, compute_sparsity_from_peaks_and_label From 73f9371e2d4c09a7298abf92d83951c608661d1c Mon Sep 17 00:00:00 2001 From: Pierre Yger Date: Fri, 5 Dec 2025 11:34:00 +0100 Subject: [PATCH 10/38] WIP --- .../sorters/internal/spyking_circus2.py | 13 +++++------ .../sortingcomponents/clustering/tools.py | 22 ++++++++++--------- src/spikeinterface/sortingcomponents/tools.py | 18 +++++---------- 3 files changed, 23 insertions(+), 30 deletions(-) diff --git a/src/spikeinterface/sorters/internal/spyking_circus2.py b/src/spikeinterface/sorters/internal/spyking_circus2.py index 84120e9577..2aee5b4d76 100644 --- a/src/spikeinterface/sorters/internal/spyking_circus2.py +++ b/src/spikeinterface/sorters/internal/spyking_circus2.py @@ -39,7 +39,7 @@ class Spykingcircus2Sorter(ComponentsBasedSorter): "motion_correction": {"preset": "dredge_fast"}, "merging": {"max_distance_um": 50}, "clustering": {"method": "iterative-hdbscan", "method_kwargs": dict()}, - "cleaning": {"min_snr": 5, "max_jitter_ms": 0.1, "sparsity_threshold": 1}, + "cleaning": {"min_snr": 5, "max_jitter_ms": 0.1, "sparsify_threshold": 1}, "matching": {"method": "circus-omp", "method_kwargs": dict(), "pipeline_kwargs": dict()}, "apply_preprocessing": True, "apply_whitening": True, @@ -364,7 +364,7 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): else: from spikeinterface.sortingcomponents.clustering.tools import get_templates_from_peaks_and_svd - dense_templates, new_sparse_mask, sd_ratios = get_templates_from_peaks_and_svd( + dense_templates, new_sparse_mask, std_at_peaks = get_templates_from_peaks_and_svd( recording_w, selected_peaks, peak_labels, @@ -374,7 +374,7 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): more_outs["peaks_svd"], more_outs["peak_svd_sparse_mask"], 
operator="median", - sd_ratios=True, + return_std_at_peaks=True, ) # this release the peak_svd memmap file templates = dense_templates.to_sparse(new_sparse_mask) @@ -386,7 +386,7 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): cleaning_kwargs = params.get("cleaning", {}).copy() cleaning_kwargs["noise_levels"] = noise_levels - # cleaning_kwargs["sd_ratios"] = sd_ratios + # cleaning_kwargs["std_at_peaks"] = std_at_peaks cleaning_kwargs["remove_empty"] = True templates = clean_templates(templates, **cleaning_kwargs) @@ -519,12 +519,9 @@ def final_cleaning_circus( final_sa.compute("spike_amplitudes", **job_kwargs) sd_ratios = compute_sd_ratio(final_sa) - ratios = np.array(list(sd_ratios.values())) - center = np.median(ratios) - mad = np.median(np.abs(ratios - center)) to_keep = [] for id, value in sd_ratios.items(): - if value <= center + mad * sd_ratio_threshold: + if value <= sd_ratio_threshold: to_keep += [id] final_sa = final_sa.select_units(to_keep) return final_sa diff --git a/src/spikeinterface/sortingcomponents/clustering/tools.py b/src/spikeinterface/sortingcomponents/clustering/tools.py index 56c9f3fa45..705383ffe8 100644 --- a/src/spikeinterface/sortingcomponents/clustering/tools.py +++ b/src/spikeinterface/sortingcomponents/clustering/tools.py @@ -236,7 +236,7 @@ def get_templates_from_peaks_and_svd( svd_features, sparsity_mask, operator="average", - sd_ratios=False, + return_std_at_peaks=False, ): """ Get templates from recording using the SVD components @@ -261,14 +261,16 @@ def get_templates_from_peaks_and_svd( The sparsity mask array. operator : str The operator to use for template estimation. Can be 'average' or 'median'. - + return_std_at_peaks : bool + Whether to return the standard deviation ratio at the peak channels. + Returns ------- dense_templates : Templates The estimated templates object as a dense template (but internanally contain sparse channels). final_sparsity_mask: np.array The final sparsity mask. 
Note that the template object is dense but with zeros. - sd_ratios: np.array + std_at_peaks: np.array The standard deviation ratio of the templates at time 0 """ @@ -285,8 +287,8 @@ def get_templates_from_peaks_and_svd( num_channels = recording.get_num_channels() templates_array = np.zeros((len(labels), nbefore + nafter, num_channels), dtype=np.float32) - if sd_ratios: - ratios = np.zeros(len(labels), dtype=np.float32) + if return_std_at_peaks: + std_at_peaks = np.zeros(len(labels), dtype=np.float32) final_sparsity_mask = np.zeros((len(labels), num_channels), dtype="bool") for unit_ind, label in enumerate(labels): @@ -304,12 +306,12 @@ def get_templates_from_peaks_and_svd( data = np.median(local_svd[sub_mask, :, count], 0) templates_array[unit_ind, :, i] = svd_model.inverse_transform(data.reshape(1, -1)) - if i == best_channel and sd_ratios: + if i == best_channel and return_std_at_peaks: data = svd_model.inverse_transform(local_svd[sub_mask, :, count]) if len(data) == 1: - ratios[unit_ind] = 0.0 + std_at_peaks[unit_ind] = 0.0 else: - ratios[unit_ind] = np.std(data[:, nbefore]) + std_at_peaks[unit_ind] = np.std(data[:, nbefore]) dense_templates = Templates( templates_array=templates_array, @@ -322,7 +324,7 @@ def get_templates_from_peaks_and_svd( is_in_uV=False, ) - if sd_ratios: - return dense_templates, final_sparsity_mask, ratios + if return_std_at_peaks: + return dense_templates, final_sparsity_mask, std_at_peaks else: return dense_templates, final_sparsity_mask diff --git a/src/spikeinterface/sortingcomponents/tools.py b/src/spikeinterface/sortingcomponents/tools.py index 2ade07e355..fb4b636ee4 100644 --- a/src/spikeinterface/sortingcomponents/tools.py +++ b/src/spikeinterface/sortingcomponents/tools.py @@ -539,26 +539,20 @@ def clean_templates( min_snr=None, max_jitter_ms=None, remove_empty=True, - sd_ratio_threshold=3.0, - sd_ratios=None, + sd_ratio_threshold=5.0, + stds_at_peak=None, ): """ Clean a Templates object by removing empty units and applying 
sparsity if provided. """ - ## First if sd_ratios are provided, we remove the templates that have a high sd_ratio - if sd_ratios is not None: + ## First if stds_at_peak are provided, we remove the templates that have a high sd_ratio + if stds_at_peak is not None: assert noise_levels is not None, "noise_levels must be provided if sd_ratios is given" to_select = [] best_channels = get_template_extremum_channel(templates, outputs="index") - all_ratios = [] for count, unit_id in enumerate(templates.unit_ids): - ratio = sd_ratios[count] / noise_levels[best_channels[unit_id]] - all_ratios.append(ratio) - centered_ratios = np.array(all_ratios) - 1.0 - mad = np.median(np.abs(centered_ratios - np.median(centered_ratios))) - - for count, unit_id in enumerate(templates.unit_ids): - if all_ratios[count] <= 1 + sd_ratio_threshold * mad: + sd_ratio = stds_at_peak[count] / noise_levels[best_channels[unit_id]] + if sd_ratio <= sd_ratio_threshold: to_select += [unit_id] templates = templates.select_units(to_select) From d2e3ba5f3f17455e3e9738772fd498db5fa2b6ec Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 5 Dec 2025 10:34:31 +0000 Subject: [PATCH 11/38] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/spikeinterface/sortingcomponents/clustering/tools.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/spikeinterface/sortingcomponents/clustering/tools.py b/src/spikeinterface/sortingcomponents/clustering/tools.py index 705383ffe8..53ec941b49 100644 --- a/src/spikeinterface/sortingcomponents/clustering/tools.py +++ b/src/spikeinterface/sortingcomponents/clustering/tools.py @@ -263,7 +263,7 @@ def get_templates_from_peaks_and_svd( The operator to use for template estimation. Can be 'average' or 'median'. return_std_at_peaks : bool Whether to return the standard deviation ratio at the peak channels. 
- + Returns ------- dense_templates : Templates From fc50c1b8d2f42a4e144c9fe586d131f2bf9ec0c8 Mon Sep 17 00:00:00 2001 From: Pierre Yger Date: Fri, 5 Dec 2025 13:35:26 +0100 Subject: [PATCH 12/38] Fixes --- src/spikeinterface/sorters/internal/spyking_circus2.py | 7 ++----- src/spikeinterface/sortingcomponents/tools.py | 2 +- 2 files changed, 3 insertions(+), 6 deletions(-) diff --git a/src/spikeinterface/sorters/internal/spyking_circus2.py b/src/spikeinterface/sorters/internal/spyking_circus2.py index 2aee5b4d76..bc08d23677 100644 --- a/src/spikeinterface/sorters/internal/spyking_circus2.py +++ b/src/spikeinterface/sorters/internal/spyking_circus2.py @@ -364,7 +364,7 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): else: from spikeinterface.sortingcomponents.clustering.tools import get_templates_from_peaks_and_svd - dense_templates, new_sparse_mask, std_at_peaks = get_templates_from_peaks_and_svd( + dense_templates, new_sparse_mask, stds_at_peak = get_templates_from_peaks_and_svd( recording_w, selected_peaks, peak_labels, @@ -381,12 +381,9 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): del more_outs - # if verbose: - # print("We have %d clusters" % len(templates.unit_ids)) - cleaning_kwargs = params.get("cleaning", {}).copy() cleaning_kwargs["noise_levels"] = noise_levels - # cleaning_kwargs["std_at_peaks"] = std_at_peaks + #cleaning_kwargs["stds_at_peak"] = stds_at_peak cleaning_kwargs["remove_empty"] = True templates = clean_templates(templates, **cleaning_kwargs) diff --git a/src/spikeinterface/sortingcomponents/tools.py b/src/spikeinterface/sortingcomponents/tools.py index fb4b636ee4..4ffec42f61 100644 --- a/src/spikeinterface/sortingcomponents/tools.py +++ b/src/spikeinterface/sortingcomponents/tools.py @@ -539,7 +539,7 @@ def clean_templates( min_snr=None, max_jitter_ms=None, remove_empty=True, - sd_ratio_threshold=5.0, + sd_ratio_threshold=3.0, stds_at_peak=None, ): """ From 984cb7f0f1b071fa03592965125f7886b7a2bd24 
Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Mon, 8 Dec 2025 13:50:47 +0100 Subject: [PATCH 13/38] clean template before merging --- src/spikeinterface/sorters/internal/lupin.py | 3 ++- .../sorters/internal/tridesclous2.py | 3 ++- .../clustering/iterative_hdbscan.py | 26 ++++++++++++++++++- .../clustering/iterative_isosplit.py | 25 ++++++++++++++++++ 4 files changed, 54 insertions(+), 3 deletions(-) diff --git a/src/spikeinterface/sorters/internal/lupin.py b/src/spikeinterface/sorters/internal/lupin.py index f7fb4d705b..1825e82ba4 100644 --- a/src/spikeinterface/sorters/internal/lupin.py +++ b/src/spikeinterface/sorters/internal/lupin.py @@ -56,6 +56,7 @@ class LupinSorter(ComponentsBasedSorter): "ms_after": 2.5, "sparsity_threshold": 1.5, "template_min_snr": 2.5, + "template_max_jitter_ms": 0.2, "gather_mode": "memory", "job_kwargs": {}, "seed": None, @@ -293,7 +294,7 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): sparsify_threshold=params["sparsity_threshold"], noise_levels=noise_levels, min_snr=params["template_min_snr"], - max_jitter_ms=None, + max_jitter_ms=params["template_max_jitter_ms"], remove_empty=True, ) diff --git a/src/spikeinterface/sorters/internal/tridesclous2.py b/src/spikeinterface/sorters/internal/tridesclous2.py index c101c67ff5..d955bcd7d4 100644 --- a/src/spikeinterface/sorters/internal/tridesclous2.py +++ b/src/spikeinterface/sorters/internal/tridesclous2.py @@ -58,6 +58,7 @@ class Tridesclous2Sorter(ComponentsBasedSorter): "sparsity_threshold": 1.5, "min_snr": 2.5, "radius_um": 100.0, + "max_jitter_ms": 0.2, }, "matching": {"method": "tdc-peeler", "method_kwargs": {}, "gather_mode": "memory"}, "job_kwargs": {}, @@ -267,7 +268,7 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): sparsify_threshold=params["templates"]["sparsity_threshold"], noise_levels=noise_levels, min_snr=params["templates"]["min_snr"], - max_jitter_ms=None, + max_jitter_ms=params["templates"]["max_jitter_ms"], 
remove_empty=True, ) diff --git a/src/spikeinterface/sortingcomponents/clustering/iterative_hdbscan.py b/src/spikeinterface/sortingcomponents/clustering/iterative_hdbscan.py index 0c089229ee..712dc5dbb2 100644 --- a/src/spikeinterface/sortingcomponents/clustering/iterative_hdbscan.py +++ b/src/spikeinterface/sortingcomponents/clustering/iterative_hdbscan.py @@ -10,7 +10,7 @@ from spikeinterface.sortingcomponents.clustering.merging_tools import merge_peak_labels_from_templates from spikeinterface.sortingcomponents.clustering.itersplit_tools import split_clusters from spikeinterface.sortingcomponents.clustering.tools import get_templates_from_peaks_and_svd - +from spikeinterface.sortingcomponents.tools import clean_templates class IterativeHDBSCANClustering: """ @@ -43,6 +43,9 @@ class IterativeHDBSCANClustering: "n_pca_features": 3, }, }, + "pre_clean_templates":{ + "max_jitter_ms" : 0.2, + }, "merge_from_templates": dict(similarity_thresh=0.8, num_shifts=3, use_lags=True), "merge_from_features": None, "debug_folder": None, @@ -128,6 +131,27 @@ def main_function(cls, recording, peaks, params, job_kwargs=dict()): operator="median", ) + ## Pre clean using templates (jitter) + cleaned_templates = clean_templates( + templates, + # sparsify_threshold=0.25, + sparsify_threshold=None, + # noise_levels=None, + # min_snr=None, + max_jitter_ms=params["pre_clean_templates"]["max_jitter_ms"], + # remove_empty=True, + remove_empty=False, + # sd_ratio_threshold=5.0, + # stds_at_peak=None, + ) + mask_keep_ids = np.isin(templates.unit_ids, cleaned_templates.unit_ids) + to_remove_ids = templates.unit_ids[~mask_keep_ids] + to_remove_label_mask = np.isin(peak_labels, to_remove_ids) + peak_labels[to_remove_label_mask] = -1 + templates = cleaned_templates + new_sparse_mask = new_sparse_mask[mask_keep_ids, :] + + labels = templates.unit_ids if verbose: diff --git a/src/spikeinterface/sortingcomponents/clustering/iterative_isosplit.py 
b/src/spikeinterface/sortingcomponents/clustering/iterative_isosplit.py index 6e08548329..5efdab5cfb 100644 --- a/src/spikeinterface/sortingcomponents/clustering/iterative_isosplit.py +++ b/src/spikeinterface/sortingcomponents/clustering/iterative_isosplit.py @@ -11,6 +11,7 @@ merge_peak_labels_from_features, ) from spikeinterface.sortingcomponents.clustering.tools import get_templates_from_peaks_and_svd +from spikeinterface.sortingcomponents.tools import clean_templates from spikeinterface.sortingcomponents.waveforms.peak_svd import extract_peaks_svd @@ -59,6 +60,9 @@ class IterativeISOSPLITClustering: # "projection_mode": "pca", }, }, + "pre_clean_templates":{ + "max_jitter_ms" : 0.2, + }, "merge_from_templates": { "similarity_metric": "l1", "num_shifts": 3, @@ -206,6 +210,27 @@ def main_function(cls, recording, peaks, params, job_kwargs=dict()): operator="average", ) + ## Pre clean using templates (jitter) + cleaned_templates = clean_templates( + dense_templates, + # sparsify_threshold=0.25, + sparsify_threshold=None, + # noise_levels=None, + # min_snr=None, + max_jitter_ms=params["pre_clean_templates"]["max_jitter_ms"], + # remove_empty=True, + remove_empty=False, + # sd_ratio_threshold=5.0, + # stds_at_peak=None, + ) + mask_keep_ids = np.isin(dense_templates.unit_ids, cleaned_templates.unit_ids) + to_remove_ids = dense_templates.unit_ids[~mask_keep_ids] + to_remove_label_mask = np.isin(post_split_label, to_remove_ids) + post_split_label[to_remove_label_mask] = -1 + dense_templates = cleaned_templates + template_sparse_mask = template_sparse_mask[mask_keep_ids, :] + + unit_ids = dense_templates.unit_ids templates_array = dense_templates.templates_array From 2725d5bb48309f34a29ae15528bce77213b45c51 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 8 Dec 2025 12:52:22 +0000 Subject: [PATCH 14/38] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci 
--- .../sortingcomponents/clustering/iterative_hdbscan.py | 6 +++--- .../sortingcomponents/clustering/iterative_isosplit.py | 5 ++--- 2 files changed, 5 insertions(+), 6 deletions(-) diff --git a/src/spikeinterface/sortingcomponents/clustering/iterative_hdbscan.py b/src/spikeinterface/sortingcomponents/clustering/iterative_hdbscan.py index 712dc5dbb2..511ecbfeed 100644 --- a/src/spikeinterface/sortingcomponents/clustering/iterative_hdbscan.py +++ b/src/spikeinterface/sortingcomponents/clustering/iterative_hdbscan.py @@ -12,6 +12,7 @@ from spikeinterface.sortingcomponents.clustering.tools import get_templates_from_peaks_and_svd from spikeinterface.sortingcomponents.tools import clean_templates + class IterativeHDBSCANClustering: """ Iterative HDBSCAN is based on several local clustering achieved with a @@ -43,8 +44,8 @@ class IterativeHDBSCANClustering: "n_pca_features": 3, }, }, - "pre_clean_templates":{ - "max_jitter_ms" : 0.2, + "pre_clean_templates": { + "max_jitter_ms": 0.2, }, "merge_from_templates": dict(similarity_thresh=0.8, num_shifts=3, use_lags=True), "merge_from_features": None, @@ -151,7 +152,6 @@ def main_function(cls, recording, peaks, params, job_kwargs=dict()): templates = cleaned_templates new_sparse_mask = new_sparse_mask[mask_keep_ids, :] - labels = templates.unit_ids if verbose: diff --git a/src/spikeinterface/sortingcomponents/clustering/iterative_isosplit.py b/src/spikeinterface/sortingcomponents/clustering/iterative_isosplit.py index 5efdab5cfb..5da3f297a7 100644 --- a/src/spikeinterface/sortingcomponents/clustering/iterative_isosplit.py +++ b/src/spikeinterface/sortingcomponents/clustering/iterative_isosplit.py @@ -60,8 +60,8 @@ class IterativeISOSPLITClustering: # "projection_mode": "pca", }, }, - "pre_clean_templates":{ - "max_jitter_ms" : 0.2, + "pre_clean_templates": { + "max_jitter_ms": 0.2, }, "merge_from_templates": { "similarity_metric": "l1", @@ -230,7 +230,6 @@ def main_function(cls, recording, peaks, params, job_kwargs=dict()): 
dense_templates = cleaned_templates template_sparse_mask = template_sparse_mask[mask_keep_ids, :] - unit_ids = dense_templates.unit_ids templates_array = dense_templates.templates_array From b48f9ecd661cdb875c205a9858e4dd3442c2abe1 Mon Sep 17 00:00:00 2001 From: Pierre Yger Date: Mon, 8 Dec 2025 17:07:35 +0100 Subject: [PATCH 15/38] WIP --- src/spikeinterface/sorters/internal/spyking_circus2.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/spikeinterface/sorters/internal/spyking_circus2.py b/src/spikeinterface/sorters/internal/spyking_circus2.py index 2aee5b4d76..b91a922941 100644 --- a/src/spikeinterface/sorters/internal/spyking_circus2.py +++ b/src/spikeinterface/sorters/internal/spyking_circus2.py @@ -39,7 +39,7 @@ class Spykingcircus2Sorter(ComponentsBasedSorter): "motion_correction": {"preset": "dredge_fast"}, "merging": {"max_distance_um": 50}, "clustering": {"method": "iterative-hdbscan", "method_kwargs": dict()}, - "cleaning": {"min_snr": 5, "max_jitter_ms": 0.1, "sparsify_threshold": 1}, + "cleaning": {"min_snr": 5, "max_jitter_ms": 0.2, "sparsify_threshold": 1}, "matching": {"method": "circus-omp", "method_kwargs": dict(), "pipeline_kwargs": dict()}, "apply_preprocessing": True, "apply_whitening": True, From b7006eb899f82e2225c990bdb8de4bb97afd2095 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 9 Dec 2025 07:33:25 +0000 Subject: [PATCH 16/38] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/spikeinterface/sorters/internal/spyking_circus2.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/spikeinterface/sorters/internal/spyking_circus2.py b/src/spikeinterface/sorters/internal/spyking_circus2.py index 2dfe0db750..44b3f34717 100644 --- a/src/spikeinterface/sorters/internal/spyking_circus2.py +++ b/src/spikeinterface/sorters/internal/spyking_circus2.py @@ -383,7 +383,7 @@ def 
_run_from_folder(cls, sorter_output_folder, params, verbose): cleaning_kwargs = params.get("cleaning", {}).copy() cleaning_kwargs["noise_levels"] = noise_levels - #cleaning_kwargs["stds_at_peak"] = stds_at_peak + # cleaning_kwargs["stds_at_peak"] = stds_at_peak cleaning_kwargs["remove_empty"] = True templates = clean_templates(templates, **cleaning_kwargs) From 90010546a50e2c437965cacfc61750ff0aa1366d Mon Sep 17 00:00:00 2001 From: Pierre Yger Date: Tue, 9 Dec 2025 09:38:05 +0100 Subject: [PATCH 17/38] Cleaning with max std per channel --- .../sorters/internal/spyking_circus2.py | 30 +++++++------------ .../clustering/iterative_hdbscan.py | 30 +++++++++++-------- .../sortingcomponents/clustering/tools.py | 26 ++++++++-------- src/spikeinterface/sortingcomponents/tools.py | 26 ++++++++-------- 4 files changed, 54 insertions(+), 58 deletions(-) diff --git a/src/spikeinterface/sorters/internal/spyking_circus2.py b/src/spikeinterface/sorters/internal/spyking_circus2.py index 2dfe0db750..ae00bdd749 100644 --- a/src/spikeinterface/sorters/internal/spyking_circus2.py +++ b/src/spikeinterface/sorters/internal/spyking_circus2.py @@ -39,7 +39,7 @@ class Spykingcircus2Sorter(ComponentsBasedSorter): "motion_correction": {"preset": "dredge_fast"}, "merging": {"max_distance_um": 50}, "clustering": {"method": "iterative-hdbscan", "method_kwargs": dict()}, - "cleaning": {"min_snr": 5, "max_jitter_ms": 0.2, "sparsify_threshold": 1}, + "cleaning": {"min_snr": 2.5, "max_jitter_ms": 0.2, "sparsify_threshold": 1, "mean_sd_ratio_threshold" : 3}, "matching": {"method": "circus-omp", "method_kwargs": dict(), "pipeline_kwargs": dict()}, "apply_preprocessing": True, "apply_whitening": True, @@ -308,6 +308,10 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): if verbose: print("Kept %d peaks for clustering" % len(selected_peaks)) + cleaning_kwargs = params.get("cleaning", {}).copy() + cleaning_kwargs["noise_levels"] = noise_levels + cleaning_kwargs["remove_empty"] = True + 
if clustering_method in [ "iterative-hdbscan", "iterative-isosplit", @@ -317,6 +321,8 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): clustering_params.update(verbose=verbose) clustering_params.update(seed=seed) clustering_params.update(peaks_svd=params["general"]) + if clustering_method == "iterative-hdbscan": + clustering_params.update(pre_clean_templates=cleaning_kwargs) if debug: clustering_params["debug_folder"] = sorter_output_folder / "clustering" @@ -364,7 +370,7 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): else: from spikeinterface.sortingcomponents.clustering.tools import get_templates_from_peaks_and_svd - dense_templates, new_sparse_mask, stds_at_peak = get_templates_from_peaks_and_svd( + dense_templates, new_sparse_mask, max_std_per_channel = get_templates_from_peaks_and_svd( recording_w, selected_peaks, peak_labels, @@ -374,17 +380,15 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): more_outs["peaks_svd"], more_outs["peak_svd_sparse_mask"], operator="median", - return_std_at_peaks=True, + return_max_std_per_channel=True, ) # this release the peak_svd memmap file templates = dense_templates.to_sparse(new_sparse_mask) del more_outs - cleaning_kwargs = params.get("cleaning", {}).copy() - cleaning_kwargs["noise_levels"] = noise_levels - #cleaning_kwargs["stds_at_peak"] = stds_at_peak - cleaning_kwargs["remove_empty"] = True + cleaning_kwargs["max_std_per_channel"] = max_std_per_channel + templates = clean_templates(templates, **cleaning_kwargs) if verbose: @@ -481,7 +485,6 @@ def final_cleaning_circus( template_diff_thresh=np.arange(0.05, 0.5, 0.05), debug_folder=None, noise_levels=None, - sd_ratio_threshold=None, job_kwargs=dict(), ): @@ -510,15 +513,4 @@ def final_cleaning_circus( sparsity_overlap=sparsity_overlap, **job_kwargs, ) - - if sd_ratio_threshold is not None: - from spikeinterface.qualitymetrics.misc_metrics import compute_sd_ratio - - final_sa.compute("spike_amplitudes", 
**job_kwargs) - sd_ratios = compute_sd_ratio(final_sa) - to_keep = [] - for id, value in sd_ratios.items(): - if value <= sd_ratio_threshold: - to_keep += [id] - final_sa = final_sa.select_units(to_keep) return final_sa diff --git a/src/spikeinterface/sortingcomponents/clustering/iterative_hdbscan.py b/src/spikeinterface/sortingcomponents/clustering/iterative_hdbscan.py index 511ecbfeed..ddda67512c 100644 --- a/src/spikeinterface/sortingcomponents/clustering/iterative_hdbscan.py +++ b/src/spikeinterface/sortingcomponents/clustering/iterative_hdbscan.py @@ -11,7 +11,7 @@ from spikeinterface.sortingcomponents.clustering.itersplit_tools import split_clusters from spikeinterface.sortingcomponents.clustering.tools import get_templates_from_peaks_and_svd from spikeinterface.sortingcomponents.tools import clean_templates - +from spikeinterface.core.recording_tools import get_noise_levels class IterativeHDBSCANClustering: """ @@ -45,6 +45,8 @@ class IterativeHDBSCANClustering: }, }, "pre_clean_templates": { + "sparsify_threshold": 1, + "remove_empty" : True, "max_jitter_ms": 0.2, }, "merge_from_templates": dict(similarity_thresh=0.8, num_shifts=3, use_lags=True), @@ -73,6 +75,7 @@ class IterativeHDBSCANClustering: @classmethod def main_function(cls, recording, peaks, params, job_kwargs=dict()): + print(params) split_radius_um = params["split"].pop("split_radius_um", 75) peaks_svd = params["peaks_svd"] ms_before = peaks_svd["ms_before"] @@ -120,7 +123,7 @@ def main_function(cls, recording, peaks, params, job_kwargs=dict()): **split, ) - templates, new_sparse_mask = get_templates_from_peaks_and_svd( + templates, new_sparse_mask, max_std_per_channel = get_templates_from_peaks_and_svd( recording, peaks, peak_labels, @@ -130,28 +133,29 @@ def main_function(cls, recording, peaks, params, job_kwargs=dict()): peaks_svd, sparse_mask, operator="median", + return_max_std_per_channel=True ) + templates = templates.to_sparse(new_sparse_mask) + + cleaning_kwargs = 
params.get("pre_clean_templates", {}).copy() + if "noise_levels" not in cleaning_kwargs: + noise_levels = get_noise_levels(recording, return_in_uV=False, **job_kwargs) + cleaning_kwargs["noise_levels"] = noise_levels + cleaning_kwargs["max_std_per_channel"] = max_std_per_channel + ## Pre clean using templates (jitter) cleaned_templates = clean_templates( templates, - # sparsify_threshold=0.25, - sparsify_threshold=None, - # noise_levels=None, - # min_snr=None, - max_jitter_ms=params["pre_clean_templates"]["max_jitter_ms"], - # remove_empty=True, - remove_empty=False, - # sd_ratio_threshold=5.0, - # stds_at_peak=None, + **cleaning_kwargs, ) mask_keep_ids = np.isin(templates.unit_ids, cleaned_templates.unit_ids) to_remove_ids = templates.unit_ids[~mask_keep_ids] to_remove_label_mask = np.isin(peak_labels, to_remove_ids) peak_labels[to_remove_label_mask] = -1 templates = cleaned_templates - new_sparse_mask = new_sparse_mask[mask_keep_ids, :] - + new_sparse_mask = templates.sparsity.mask.copy() + templates = templates.to_dense() labels = templates.unit_ids if verbose: diff --git a/src/spikeinterface/sortingcomponents/clustering/tools.py b/src/spikeinterface/sortingcomponents/clustering/tools.py index 53ec941b49..f6a6029d0d 100644 --- a/src/spikeinterface/sortingcomponents/clustering/tools.py +++ b/src/spikeinterface/sortingcomponents/clustering/tools.py @@ -236,7 +236,7 @@ def get_templates_from_peaks_and_svd( svd_features, sparsity_mask, operator="average", - return_std_at_peaks=False, + return_max_std_per_channel=False, ): """ Get templates from recording using the SVD components @@ -261,8 +261,8 @@ def get_templates_from_peaks_and_svd( The sparsity mask array. operator : str The operator to use for template estimation. Can be 'average' or 'median'. - return_std_at_peaks : bool - Whether to return the standard deviation ratio at the peak channels. + return_max_std_per_channel : bool + Whether to return the max standard deviation at the channels. 
Returns ------- @@ -270,8 +270,8 @@ def get_templates_from_peaks_and_svd( The estimated templates object as a dense template (but internanally contain sparse channels). final_sparsity_mask: np.array The final sparsity mask. Note that the template object is dense but with zeros. - std_at_peaks: np.array - The standard deviation ratio of the templates at time 0 + max_std_per_channel: np.array + The maximal standard deviation of the templates per channel (only if return_max_std_per_channel is True). """ assert operator in ["average", "median"], "operator should be either 'average' or 'median'" @@ -287,8 +287,8 @@ def get_templates_from_peaks_and_svd( num_channels = recording.get_num_channels() templates_array = np.zeros((len(labels), nbefore + nafter, num_channels), dtype=np.float32) - if return_std_at_peaks: - std_at_peaks = np.zeros(len(labels), dtype=np.float32) + if return_max_std_per_channel: + max_std_per_channel = np.zeros((len(labels), num_channels), dtype=np.float32) final_sparsity_mask = np.zeros((len(labels), num_channels), dtype="bool") for unit_ind, label in enumerate(labels): @@ -306,12 +306,10 @@ def get_templates_from_peaks_and_svd( data = np.median(local_svd[sub_mask, :, count], 0) templates_array[unit_ind, :, i] = svd_model.inverse_transform(data.reshape(1, -1)) - if i == best_channel and return_std_at_peaks: + if return_max_std_per_channel: data = svd_model.inverse_transform(local_svd[sub_mask, :, count]) - if len(data) == 1: - std_at_peaks[unit_ind] = 0.0 - else: - std_at_peaks[unit_ind] = np.std(data[:, nbefore]) + if len(data) > 1: + max_std_per_channel[unit_ind, i] = np.std(data, 0).max() dense_templates = Templates( templates_array=templates_array, @@ -324,7 +322,7 @@ def get_templates_from_peaks_and_svd( is_in_uV=False, ) - if return_std_at_peaks: - return dense_templates, final_sparsity_mask, std_at_peaks + if return_max_std_per_channel: + return dense_templates, final_sparsity_mask, max_std_per_channel else: return dense_templates, 
final_sparsity_mask diff --git a/src/spikeinterface/sortingcomponents/tools.py b/src/spikeinterface/sortingcomponents/tools.py index 4ffec42f61..323914241a 100644 --- a/src/spikeinterface/sortingcomponents/tools.py +++ b/src/spikeinterface/sortingcomponents/tools.py @@ -539,25 +539,16 @@ def clean_templates( min_snr=None, max_jitter_ms=None, remove_empty=True, - sd_ratio_threshold=3.0, - stds_at_peak=None, + mean_sd_ratio_threshold=3.0, + max_std_per_channel=None, ): """ Clean a Templates object by removing empty units and applying sparsity if provided. """ - ## First if stds_at_peak are provided, we remove the templates that have a high sd_ratio - if stds_at_peak is not None: - assert noise_levels is not None, "noise_levels must be provided if sd_ratios is given" - to_select = [] - best_channels = get_template_extremum_channel(templates, outputs="index") - for count, unit_id in enumerate(templates.unit_ids): - sd_ratio = stds_at_peak[count] / noise_levels[best_channels[unit_id]] - if sd_ratio <= sd_ratio_threshold: - to_select += [unit_id] - templates = templates.select_units(to_select) ## First we sparsify the templates (using peak-to-peak amplitude avoid sign issues) if sparsify_threshold is not None: + assert noise_levels is not None, "noise_levels must be provided if sparsify_threshold is set" if templates.are_templates_sparse(): templates = templates.to_dense() sparsity = compute_sparsity( @@ -597,6 +588,17 @@ def clean_templates( to_select = templates.unit_ids[np.flatnonzero(sparsity.mask.sum(axis=1) > 0)] templates = templates.select_units(to_select) + ## Lastly, if stds_at_peak are provided, we remove the templates that have a too high sd_ratio + if max_std_per_channel is not None: + assert noise_levels is not None, "noise_levels must be provided if max_std_per_channel is given" + to_select = [] + for count, unit_id in enumerate(templates.unit_ids): + mask = templates.sparsity.mask[count, :] + sd_ratio = np.mean(max_std_per_channel[count][mask] / 
noise_levels[mask]) + if sd_ratio <= mean_sd_ratio_threshold: + to_select += [unit_id] + templates = templates.select_units(to_select) + return templates From 7201a89f2f228da10001f277e9040e01b71cfb25 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 9 Dec 2025 08:40:26 +0000 Subject: [PATCH 18/38] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/spikeinterface/sorters/internal/spyking_circus2.py | 4 ++-- .../sortingcomponents/clustering/iterative_hdbscan.py | 5 +++-- 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/src/spikeinterface/sorters/internal/spyking_circus2.py b/src/spikeinterface/sorters/internal/spyking_circus2.py index ae00bdd749..ebf19fafb8 100644 --- a/src/spikeinterface/sorters/internal/spyking_circus2.py +++ b/src/spikeinterface/sorters/internal/spyking_circus2.py @@ -39,7 +39,7 @@ class Spykingcircus2Sorter(ComponentsBasedSorter): "motion_correction": {"preset": "dredge_fast"}, "merging": {"max_distance_um": 50}, "clustering": {"method": "iterative-hdbscan", "method_kwargs": dict()}, - "cleaning": {"min_snr": 2.5, "max_jitter_ms": 0.2, "sparsify_threshold": 1, "mean_sd_ratio_threshold" : 3}, + "cleaning": {"min_snr": 2.5, "max_jitter_ms": 0.2, "sparsify_threshold": 1, "mean_sd_ratio_threshold": 3}, "matching": {"method": "circus-omp", "method_kwargs": dict(), "pipeline_kwargs": dict()}, "apply_preprocessing": True, "apply_whitening": True, @@ -388,7 +388,7 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): del more_outs cleaning_kwargs["max_std_per_channel"] = max_std_per_channel - + templates = clean_templates(templates, **cleaning_kwargs) if verbose: diff --git a/src/spikeinterface/sortingcomponents/clustering/iterative_hdbscan.py b/src/spikeinterface/sortingcomponents/clustering/iterative_hdbscan.py index ddda67512c..c06dc4e1d8 100644 --- 
a/src/spikeinterface/sortingcomponents/clustering/iterative_hdbscan.py +++ b/src/spikeinterface/sortingcomponents/clustering/iterative_hdbscan.py @@ -13,6 +13,7 @@ from spikeinterface.sortingcomponents.tools import clean_templates from spikeinterface.core.recording_tools import get_noise_levels + class IterativeHDBSCANClustering: """ Iterative HDBSCAN is based on several local clustering achieved with a @@ -46,7 +47,7 @@ class IterativeHDBSCANClustering: }, "pre_clean_templates": { "sparsify_threshold": 1, - "remove_empty" : True, + "remove_empty": True, "max_jitter_ms": 0.2, }, "merge_from_templates": dict(similarity_thresh=0.8, num_shifts=3, use_lags=True), @@ -133,7 +134,7 @@ def main_function(cls, recording, peaks, params, job_kwargs=dict()): peaks_svd, sparse_mask, operator="median", - return_max_std_per_channel=True + return_max_std_per_channel=True, ) templates = templates.to_sparse(new_sparse_mask) From 2f5dd0d6bde6fef01ac0477d0a4349ab95ccad5e Mon Sep 17 00:00:00 2001 From: Pierre Yger Date: Tue, 9 Dec 2025 11:34:41 +0100 Subject: [PATCH 19/38] WIP --- .../clustering/iterative_hdbscan.py | 2 +- src/spikeinterface/sortingcomponents/tools.py | 25 +++++++++++++++++-- 2 files changed, 24 insertions(+), 3 deletions(-) diff --git a/src/spikeinterface/sortingcomponents/clustering/iterative_hdbscan.py b/src/spikeinterface/sortingcomponents/clustering/iterative_hdbscan.py index ddda67512c..8311ff9b18 100644 --- a/src/spikeinterface/sortingcomponents/clustering/iterative_hdbscan.py +++ b/src/spikeinterface/sortingcomponents/clustering/iterative_hdbscan.py @@ -75,7 +75,6 @@ class IterativeHDBSCANClustering: @classmethod def main_function(cls, recording, peaks, params, job_kwargs=dict()): - print(params) split_radius_um = params["split"].pop("split_radius_um", 75) peaks_svd = params["peaks_svd"] ms_before = peaks_svd["ms_before"] @@ -148,6 +147,7 @@ def main_function(cls, recording, peaks, params, job_kwargs=dict()): cleaned_templates = clean_templates( templates, 
**cleaning_kwargs, + verbose=True ) mask_keep_ids = np.isin(templates.unit_ids, cleaned_templates.unit_ids) to_remove_ids = templates.unit_ids[~mask_keep_ids] diff --git a/src/spikeinterface/sortingcomponents/tools.py b/src/spikeinterface/sortingcomponents/tools.py index 323914241a..b603921687 100644 --- a/src/spikeinterface/sortingcomponents/tools.py +++ b/src/spikeinterface/sortingcomponents/tools.py @@ -541,10 +541,13 @@ def clean_templates( remove_empty=True, mean_sd_ratio_threshold=3.0, max_std_per_channel=None, + verbose=True ): """ Clean a Templates object by removing empty units and applying sparsity if provided. """ + + initial_ids = templates.unit_ids.copy() ## First we sparsify the templates (using peak-to-peak amplitude avoid sign issues) if sparsify_threshold is not None: @@ -562,22 +565,30 @@ def clean_templates( ## We removed non empty templates if remove_empty: + n_before = len(templates.unit_ids) templates = remove_empty_templates(templates) + if verbose: + n_after = len(templates.unit_ids) + print(f"Removed {n_before - n_after} empty templates") ## We keep only units with a max jitter if max_jitter_ms is not None: max_jitter = int(max_jitter_ms * templates.sampling_frequency / 1000.0) - + n_before = len(templates.unit_ids) shifts = get_template_extremum_channel_peak_shift(templates) to_select = [] for unit_id in templates.unit_ids: if np.abs(shifts[unit_id]) <= max_jitter: to_select += [unit_id] templates = templates.select_units(to_select) + if verbose: + n_after = len(templates.unit_ids) + print(f"Removed {n_before - n_after} unaligned templates") ## We remove units with a low SNR if min_snr is not None: assert noise_levels is not None, "noise_levels must be provided if min_snr is set" + n_before = len(templates.unit_ids) sparsity = compute_sparsity( templates.to_dense(), method="snr", @@ -587,17 +598,27 @@ def clean_templates( ) to_select = templates.unit_ids[np.flatnonzero(sparsity.mask.sum(axis=1) > 0)] templates = 
templates.select_units(to_select) + if verbose: + n_after = len(templates.unit_ids) + print(f"Removed {n_before - n_after} templates with too low SNR") ## Lastly, if stds_at_peak are provided, we remove the templates that have a too high sd_ratio if max_std_per_channel is not None: assert noise_levels is not None, "noise_levels must be provided if max_std_per_channel is given" to_select = [] + n_before = len(templates.unit_ids) + all_ratios = [] for count, unit_id in enumerate(templates.unit_ids): + old_index = np.where(unit_id == initial_ids)[0][0] mask = templates.sparsity.mask[count, :] - sd_ratio = np.mean(max_std_per_channel[count][mask] / noise_levels[mask]) + sd_ratio = np.mean(max_std_per_channel[old_index][mask] / noise_levels[mask]) + all_ratios += [sd_ratio] if sd_ratio <= mean_sd_ratio_threshold: to_select += [unit_id] templates = templates.select_units(to_select) + if verbose: + n_after = len(templates.unit_ids) + print(f"Removed {n_before - n_after} templates with too high mean sd / noise ratio") return templates From c8b49d9e23fe2ecd8588f9c3fb78092a99aae83a Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 9 Dec 2025 10:36:08 +0000 Subject: [PATCH 20/38] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- .../sortingcomponents/clustering/iterative_hdbscan.py | 6 +----- src/spikeinterface/sortingcomponents/tools.py | 4 ++-- 2 files changed, 3 insertions(+), 7 deletions(-) diff --git a/src/spikeinterface/sortingcomponents/clustering/iterative_hdbscan.py b/src/spikeinterface/sortingcomponents/clustering/iterative_hdbscan.py index 3725778ceb..eba7bb1cb5 100644 --- a/src/spikeinterface/sortingcomponents/clustering/iterative_hdbscan.py +++ b/src/spikeinterface/sortingcomponents/clustering/iterative_hdbscan.py @@ -145,11 +145,7 @@ def main_function(cls, recording, peaks, params, job_kwargs=dict()): cleaning_kwargs["max_std_per_channel"] = 
max_std_per_channel ## Pre clean using templates (jitter) - cleaned_templates = clean_templates( - templates, - **cleaning_kwargs, - verbose=True - ) + cleaned_templates = clean_templates(templates, **cleaning_kwargs, verbose=True) mask_keep_ids = np.isin(templates.unit_ids, cleaned_templates.unit_ids) to_remove_ids = templates.unit_ids[~mask_keep_ids] to_remove_label_mask = np.isin(peak_labels, to_remove_ids) diff --git a/src/spikeinterface/sortingcomponents/tools.py b/src/spikeinterface/sortingcomponents/tools.py index b603921687..3f2bc8728a 100644 --- a/src/spikeinterface/sortingcomponents/tools.py +++ b/src/spikeinterface/sortingcomponents/tools.py @@ -541,12 +541,12 @@ def clean_templates( remove_empty=True, mean_sd_ratio_threshold=3.0, max_std_per_channel=None, - verbose=True + verbose=True, ): """ Clean a Templates object by removing empty units and applying sparsity if provided. """ - + initial_ids = templates.unit_ids.copy() ## First we sparsify the templates (using peak-to-peak amplitude avoid sign issues) From 131366b32d56d0b18a5a2670abbcda0bfbc54802 Mon Sep 17 00:00:00 2001 From: Pierre Yger Date: Tue, 9 Dec 2025 11:51:41 +0100 Subject: [PATCH 21/38] Cleaning --- .../sortingcomponents/clustering/iterative_hdbscan.py | 7 ++++--- src/spikeinterface/sortingcomponents/tools.py | 2 -- 2 files changed, 4 insertions(+), 5 deletions(-) diff --git a/src/spikeinterface/sortingcomponents/clustering/iterative_hdbscan.py b/src/spikeinterface/sortingcomponents/clustering/iterative_hdbscan.py index 3725778ceb..f37cb78cb3 100644 --- a/src/spikeinterface/sortingcomponents/clustering/iterative_hdbscan.py +++ b/src/spikeinterface/sortingcomponents/clustering/iterative_hdbscan.py @@ -139,16 +139,17 @@ def main_function(cls, recording, peaks, params, job_kwargs=dict()): templates = templates.to_sparse(new_sparse_mask) cleaning_kwargs = params.get("pre_clean_templates", {}).copy() + cleaning_kwargs["verbose"] = verbose + cleaning_kwargs["max_std_per_channel"] = 
max_std_per_channel if "noise_levels" not in cleaning_kwargs: noise_levels = get_noise_levels(recording, return_in_uV=False, **job_kwargs) cleaning_kwargs["noise_levels"] = noise_levels - cleaning_kwargs["max_std_per_channel"] = max_std_per_channel + ## Pre clean using templates (jitter) cleaned_templates = clean_templates( templates, - **cleaning_kwargs, - verbose=True + **cleaning_kwargs ) mask_keep_ids = np.isin(templates.unit_ids, cleaned_templates.unit_ids) to_remove_ids = templates.unit_ids[~mask_keep_ids] diff --git a/src/spikeinterface/sortingcomponents/tools.py b/src/spikeinterface/sortingcomponents/tools.py index b603921687..6a244ff6be 100644 --- a/src/spikeinterface/sortingcomponents/tools.py +++ b/src/spikeinterface/sortingcomponents/tools.py @@ -607,12 +607,10 @@ def clean_templates( assert noise_levels is not None, "noise_levels must be provided if max_std_per_channel is given" to_select = [] n_before = len(templates.unit_ids) - all_ratios = [] for count, unit_id in enumerate(templates.unit_ids): old_index = np.where(unit_id == initial_ids)[0][0] mask = templates.sparsity.mask[count, :] sd_ratio = np.mean(max_std_per_channel[old_index][mask] / noise_levels[mask]) - all_ratios += [sd_ratio] if sd_ratio <= mean_sd_ratio_threshold: to_select += [unit_id] templates = templates.select_units(to_select) From 7619e92a4e1625e238d5251a356a1f6fad1a8d02 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 9 Dec 2025 10:52:46 +0000 Subject: [PATCH 22/38] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- .../sortingcomponents/clustering/iterative_hdbscan.py | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/src/spikeinterface/sortingcomponents/clustering/iterative_hdbscan.py b/src/spikeinterface/sortingcomponents/clustering/iterative_hdbscan.py index e710a6040c..5ce70bb979 100644 --- 
a/src/spikeinterface/sortingcomponents/clustering/iterative_hdbscan.py +++ b/src/spikeinterface/sortingcomponents/clustering/iterative_hdbscan.py @@ -144,12 +144,9 @@ def main_function(cls, recording, peaks, params, job_kwargs=dict()): if "noise_levels" not in cleaning_kwargs: noise_levels = get_noise_levels(recording, return_in_uV=False, **job_kwargs) cleaning_kwargs["noise_levels"] = noise_levels - + ## Pre clean using templates (jitter) - cleaned_templates = clean_templates( - templates, - **cleaning_kwargs - ) + cleaned_templates = clean_templates(templates, **cleaning_kwargs) mask_keep_ids = np.isin(templates.unit_ids, cleaned_templates.unit_ids) to_remove_ids = templates.unit_ids[~mask_keep_ids] to_remove_label_mask = np.isin(peak_labels, to_remove_ids) From d790ad147bdd7ada5a2f7a5a24f58c74d91e55d8 Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Tue, 9 Dec 2025 15:16:57 +0100 Subject: [PATCH 23/38] more pre clean clustering --- src/spikeinterface/sorters/internal/lupin.py | 10 ++- .../sorters/internal/spyking_circus2.py | 6 +- .../sorters/internal/tridesclous2.py | 1 + .../clustering/iterative_hdbscan.py | 18 +++-- .../clustering/iterative_isosplit.py | 75 ++++++++++++++----- 5 files changed, 77 insertions(+), 33 deletions(-) diff --git a/src/spikeinterface/sorters/internal/lupin.py b/src/spikeinterface/sorters/internal/lupin.py index 1825e82ba4..2743798e85 100644 --- a/src/spikeinterface/sorters/internal/lupin.py +++ b/src/spikeinterface/sorters/internal/lupin.py @@ -54,7 +54,7 @@ class LupinSorter(ComponentsBasedSorter): "clustering_recursive_depth": 3, "ms_before": 1.0, "ms_after": 2.5, - "sparsity_threshold": 1.5, + "template_sparsify_threshold": 1.5, "template_min_snr": 2.5, "template_max_jitter_ms": 0.2, "gather_mode": "memory", @@ -81,7 +81,7 @@ class LupinSorter(ComponentsBasedSorter): "clustering_recursive_depth": "Clustering recussivity", "ms_before": "Milliseconds before the spike peak for template matching", "ms_after": "Milliseconds after the 
spike peak for template matching", - "sparsity_threshold": "Threshold to sparsify templates before template matching", + "template_sparsify_threshold": "Threshold to sparsify templates before template matching", "template_min_snr": "Threshold to remove templates before template matching", "gather_mode": "How to accumalte spike in matching : memory/npy", "job_kwargs": "The famous and fabulous job_kwargs", @@ -233,6 +233,10 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): clustering_kwargs["peaks_svd"]["n_components"] = params["n_svd_components_per_channel"] clustering_kwargs["split"]["recursive_depth"] = params["clustering_recursive_depth"] clustering_kwargs["split"]["method_kwargs"]["n_pca_features"] = params["n_pca_features"] + clustering_kwargs["clean_templates"]["sparsify_threshold"] = params["template_sparsify_threshold"] + clustering_kwargs["clean_templates"]["template_min_snr"] = params["template_min_snr"] + clustering_kwargs["clean_templates"]["template_max_jitter_ms"] = params["template_max_jitter_ms"] + clustering_kwargs["noise_levels"] = noise_levels if params["debug"]: clustering_kwargs["debug_folder"] = sorter_output_folder @@ -291,7 +295,7 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): # this spasify more templates = clean_templates( templates, - sparsify_threshold=params["sparsity_threshold"], + sparsify_threshold=params["template_sparsify_threshold"], noise_levels=noise_levels, min_snr=params["template_min_snr"], max_jitter_ms=params["template_max_jitter_ms"], diff --git a/src/spikeinterface/sorters/internal/spyking_circus2.py b/src/spikeinterface/sorters/internal/spyking_circus2.py index ebf19fafb8..1029d63730 100644 --- a/src/spikeinterface/sorters/internal/spyking_circus2.py +++ b/src/spikeinterface/sorters/internal/spyking_circus2.py @@ -309,7 +309,6 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): print("Kept %d peaks for clustering" % len(selected_peaks)) cleaning_kwargs = 
params.get("cleaning", {}).copy() - cleaning_kwargs["noise_levels"] = noise_levels cleaning_kwargs["remove_empty"] = True if clustering_method in [ @@ -322,9 +321,12 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): clustering_params.update(seed=seed) clustering_params.update(peaks_svd=params["general"]) if clustering_method == "iterative-hdbscan": - clustering_params.update(pre_clean_templates=cleaning_kwargs) + clustering_params.update(clean_templates=cleaning_kwargs) + if clustering_method in ("iterative-hdbscan", "iterative-isosplit"): + clustering_params["noise_levels"] = noise_levels if debug: clustering_params["debug_folder"] = sorter_output_folder / "clustering" + _, peak_labels, more_outs = find_clusters_from_peaks( recording_w, diff --git a/src/spikeinterface/sorters/internal/tridesclous2.py b/src/spikeinterface/sorters/internal/tridesclous2.py index d955bcd7d4..a52fc92e86 100644 --- a/src/spikeinterface/sorters/internal/tridesclous2.py +++ b/src/spikeinterface/sorters/internal/tridesclous2.py @@ -194,6 +194,7 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): clustering_kwargs["split"].update(params["clustering"]) if params["debug"]: clustering_kwargs["debug_folder"] = sorter_output_folder + clustering_kwargs["noise_levels"] = noise_levels # if clustering_kwargs["clustering"]["clusterer"] == "isosplit6": # have_sisosplit6 = importlib.util.find_spec("isosplit6") is not None diff --git a/src/spikeinterface/sortingcomponents/clustering/iterative_hdbscan.py b/src/spikeinterface/sortingcomponents/clustering/iterative_hdbscan.py index 5ce70bb979..567dead058 100644 --- a/src/spikeinterface/sortingcomponents/clustering/iterative_hdbscan.py +++ b/src/spikeinterface/sortingcomponents/clustering/iterative_hdbscan.py @@ -32,6 +32,7 @@ class IterativeHDBSCANClustering: _default_params = { "peaks_svd": {"n_components": 5, "ms_before": 0.5, "ms_after": 1.5, "radius_um": 100.0}, "seed": None, + "noise_levels": None, "split": { 
"split_radius_um": 75.0, "recursive": True, @@ -45,8 +46,9 @@ class IterativeHDBSCANClustering: "n_pca_features": 3, }, }, - "pre_clean_templates": { - "sparsify_threshold": 1, + "clean_templates": { + "sparsify_threshold": 1., + "min_snr" : 2.5, "remove_empty": True, "max_jitter_ms": 0.2, }, @@ -136,16 +138,16 @@ def main_function(cls, recording, peaks, params, job_kwargs=dict()): return_max_std_per_channel=True, ) + ## Pre clean using templates (jitter, sparsify_threshold) templates = templates.to_sparse(new_sparse_mask) - - cleaning_kwargs = params.get("pre_clean_templates", {}).copy() + cleaning_kwargs = params["clean_templates"].copy() cleaning_kwargs["verbose"] = verbose cleaning_kwargs["max_std_per_channel"] = max_std_per_channel - if "noise_levels" not in cleaning_kwargs: + if params["noise_levels"] is not None: + noise_levels = params["noise_levels"] + else: noise_levels = get_noise_levels(recording, return_in_uV=False, **job_kwargs) - cleaning_kwargs["noise_levels"] = noise_levels - - ## Pre clean using templates (jitter) + cleaning_kwargs["noise_levels"] = noise_levels cleaned_templates = clean_templates(templates, **cleaning_kwargs) mask_keep_ids = np.isin(templates.unit_ids, cleaned_templates.unit_ids) to_remove_ids = templates.unit_ids[~mask_keep_ids] diff --git a/src/spikeinterface/sortingcomponents/clustering/iterative_isosplit.py b/src/spikeinterface/sortingcomponents/clustering/iterative_isosplit.py index 5da3f297a7..5a44ff4148 100644 --- a/src/spikeinterface/sortingcomponents/clustering/iterative_isosplit.py +++ b/src/spikeinterface/sortingcomponents/clustering/iterative_isosplit.py @@ -13,6 +13,7 @@ from spikeinterface.sortingcomponents.clustering.tools import get_templates_from_peaks_and_svd from spikeinterface.sortingcomponents.tools import clean_templates from spikeinterface.sortingcomponents.waveforms.peak_svd import extract_peaks_svd +from spikeinterface.core.recording_tools import get_noise_levels class IterativeISOSPLITClustering: @@ 
-33,6 +34,7 @@ class IterativeISOSPLITClustering: _default_params = { "motion": None, "seed": None, + "noise_levels": None, "peaks_svd": {"n_components": 5, "ms_before": 0.5, "ms_after": 1.5, "radius_um": 120.0, "motion": None}, "pre_label": { "mode": "channel", @@ -60,8 +62,11 @@ class IterativeISOSPLITClustering: # "projection_mode": "pca", }, }, - "pre_clean_templates": { + "clean_templates": { "max_jitter_ms": 0.2, + "min_snr" : 2.5, + "sparsify_threshold" : 1.0, + "remove_empty": True, }, "merge_from_templates": { "similarity_metric": "l1", @@ -101,6 +106,7 @@ def main_function(cls, recording, peaks, params, job_kwargs=dict()): ms_before = params["peaks_svd"]["ms_before"] ms_after = params["peaks_svd"]["ms_after"] # radius_um = params["waveforms"]["radius_um"] + verbose = params["verbose"] debug_folder = params["debug_folder"] @@ -210,28 +216,57 @@ def main_function(cls, recording, peaks, params, job_kwargs=dict()): operator="average", ) - ## Pre clean using templates (jitter) - cleaned_templates = clean_templates( - dense_templates, - # sparsify_threshold=0.25, - sparsify_threshold=None, - # noise_levels=None, - # min_snr=None, - max_jitter_ms=params["pre_clean_templates"]["max_jitter_ms"], - # remove_empty=True, - remove_empty=False, - # sd_ratio_threshold=5.0, - # stds_at_peak=None, - ) - mask_keep_ids = np.isin(dense_templates.unit_ids, cleaned_templates.unit_ids) - to_remove_ids = dense_templates.unit_ids[~mask_keep_ids] + + ## Pre clean using templates (jitter, sparsify_threshold) + templates = dense_templates.to_sparse(template_sparse_mask) + cleaning_kwargs = params["clean_templates"].copy() + # cleaning_kwargs["verbose"] = verbose + cleaning_kwargs["verbose"] = True + # cleaning_kwargs["max_std_per_channel"] = max_std_per_channel + if params["noise_levels"] is not None: + noise_levels = params["noise_levels"] + else: + noise_levels = get_noise_levels(recording, return_in_uV=False, **job_kwargs) + cleaning_kwargs["noise_levels"] = noise_levels + 
cleaned_templates = clean_templates(templates, **cleaning_kwargs) + mask_keep_ids = np.isin(templates.unit_ids, cleaned_templates.unit_ids) + to_remove_ids = templates.unit_ids[~mask_keep_ids] to_remove_label_mask = np.isin(post_split_label, to_remove_ids) post_split_label[to_remove_label_mask] = -1 - dense_templates = cleaned_templates - template_sparse_mask = template_sparse_mask[mask_keep_ids, :] - - unit_ids = dense_templates.unit_ids + template_sparse_mask = cleaned_templates.sparsity.mask.copy() + dense_templates = cleaned_templates.to_dense() templates_array = dense_templates.templates_array + unit_ids = dense_templates.unit_ids + + + + + + # ## Pre clean using templates (jitter) + # cleaned_templates = clean_templates( + # dense_templates, + # # sparsify_threshold=0.25, + # sparsify_threshold=None, + # # noise_levels=None, + # # min_snr=None, + # max_jitter_ms=params["clean_templates"]["max_jitter_ms"], + # # remove_empty=True, + # remove_empty=False, + # # sd_ratio_threshold=5.0, + # # stds_at_peak=None, + # ) + # mask_keep_ids = np.isin(dense_templates.unit_ids, cleaned_templates.unit_ids) + # to_remove_ids = dense_templates.unit_ids[~mask_keep_ids] + # to_remove_label_mask = np.isin(post_split_label, to_remove_ids) + # post_split_label[to_remove_label_mask] = -1 + # dense_templates = cleaned_templates + # template_sparse_mask = template_sparse_mask[mask_keep_ids, :] + + # unit_ids = dense_templates.unit_ids + # templates_array = dense_templates.templates_array + + + if params["merge_from_features"] is not None: From 0736f55dbad452f83c343a7584cd3c1ac8d6ecec Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 9 Dec 2025 14:17:42 +0000 Subject: [PATCH 24/38] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- .../sorters/internal/spyking_circus2.py | 1 - .../clustering/iterative_hdbscan.py | 4 ++-- .../clustering/iterative_isosplit.py | 12 
++---------- 3 files changed, 4 insertions(+), 13 deletions(-) diff --git a/src/spikeinterface/sorters/internal/spyking_circus2.py b/src/spikeinterface/sorters/internal/spyking_circus2.py index 1029d63730..08ae71cf9e 100644 --- a/src/spikeinterface/sorters/internal/spyking_circus2.py +++ b/src/spikeinterface/sorters/internal/spyking_circus2.py @@ -326,7 +326,6 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): clustering_params["noise_levels"] = noise_levels if debug: clustering_params["debug_folder"] = sorter_output_folder / "clustering" - _, peak_labels, more_outs = find_clusters_from_peaks( recording_w, diff --git a/src/spikeinterface/sortingcomponents/clustering/iterative_hdbscan.py b/src/spikeinterface/sortingcomponents/clustering/iterative_hdbscan.py index 567dead058..d7f68de902 100644 --- a/src/spikeinterface/sortingcomponents/clustering/iterative_hdbscan.py +++ b/src/spikeinterface/sortingcomponents/clustering/iterative_hdbscan.py @@ -47,8 +47,8 @@ class IterativeHDBSCANClustering: }, }, "clean_templates": { - "sparsify_threshold": 1., - "min_snr" : 2.5, + "sparsify_threshold": 1.0, + "min_snr": 2.5, "remove_empty": True, "max_jitter_ms": 0.2, }, diff --git a/src/spikeinterface/sortingcomponents/clustering/iterative_isosplit.py b/src/spikeinterface/sortingcomponents/clustering/iterative_isosplit.py index 5a44ff4148..7598a6a095 100644 --- a/src/spikeinterface/sortingcomponents/clustering/iterative_isosplit.py +++ b/src/spikeinterface/sortingcomponents/clustering/iterative_isosplit.py @@ -64,8 +64,8 @@ class IterativeISOSPLITClustering: }, "clean_templates": { "max_jitter_ms": 0.2, - "min_snr" : 2.5, - "sparsify_threshold" : 1.0, + "min_snr": 2.5, + "sparsify_threshold": 1.0, "remove_empty": True, }, "merge_from_templates": { @@ -216,7 +216,6 @@ def main_function(cls, recording, peaks, params, job_kwargs=dict()): operator="average", ) - ## Pre clean using templates (jitter, sparsify_threshold) templates = 
dense_templates.to_sparse(template_sparse_mask) cleaning_kwargs = params["clean_templates"].copy() @@ -238,10 +237,6 @@ def main_function(cls, recording, peaks, params, job_kwargs=dict()): templates_array = dense_templates.templates_array unit_ids = dense_templates.unit_ids - - - - # ## Pre clean using templates (jitter) # cleaned_templates = clean_templates( # dense_templates, @@ -265,9 +260,6 @@ def main_function(cls, recording, peaks, params, job_kwargs=dict()): # unit_ids = dense_templates.unit_ids # templates_array = dense_templates.templates_array - - - if params["merge_from_features"] is not None: merge_from_features_kwargs = params["merge_from_features"].copy() From f9bca31ef4d1c039b980eb99c15fcc9d43aab9ee Mon Sep 17 00:00:00 2001 From: Pierre Yger Date: Tue, 9 Dec 2025 21:08:27 +0100 Subject: [PATCH 25/38] Broken SC2 left by Sam --- src/spikeinterface/sorters/internal/spyking_circus2.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/spikeinterface/sorters/internal/spyking_circus2.py b/src/spikeinterface/sorters/internal/spyking_circus2.py index 08ae71cf9e..e63f895733 100644 --- a/src/spikeinterface/sorters/internal/spyking_circus2.py +++ b/src/spikeinterface/sorters/internal/spyking_circus2.py @@ -320,10 +320,10 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): clustering_params.update(verbose=verbose) clustering_params.update(seed=seed) clustering_params.update(peaks_svd=params["general"]) - if clustering_method == "iterative-hdbscan": + if clustering_method in ["iterative-hdbscan", "iterative-isosplit"]: clustering_params.update(clean_templates=cleaning_kwargs) - if clustering_method in ("iterative-hdbscan", "iterative-isosplit"): clustering_params["noise_levels"] = noise_levels + if debug: clustering_params["debug_folder"] = sorter_output_folder / "clustering" @@ -390,7 +390,7 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): cleaning_kwargs["max_std_per_channel"] = max_std_per_channel - 
templates = clean_templates(templates, **cleaning_kwargs) + templates = clean_templates(templates, noise_levels=noise_levels, **cleaning_kwargs) if verbose: print("Kept %d clean clusters" % len(templates.unit_ids)) From ca5a0778606599bfe7cd49cd18a290776ba835f8 Mon Sep 17 00:00:00 2001 From: Pierre Yger Date: Tue, 9 Dec 2025 21:10:33 +0100 Subject: [PATCH 26/38] Verbose --- src/spikeinterface/sorters/internal/spyking_circus2.py | 2 +- src/spikeinterface/sortingcomponents/tools.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/spikeinterface/sorters/internal/spyking_circus2.py b/src/spikeinterface/sorters/internal/spyking_circus2.py index e63f895733..c37f626daa 100644 --- a/src/spikeinterface/sorters/internal/spyking_circus2.py +++ b/src/spikeinterface/sorters/internal/spyking_circus2.py @@ -389,7 +389,7 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): del more_outs cleaning_kwargs["max_std_per_channel"] = max_std_per_channel - + cleaning_kwargs["verbose"] = verbose templates = clean_templates(templates, noise_levels=noise_levels, **cleaning_kwargs) if verbose: diff --git a/src/spikeinterface/sortingcomponents/tools.py b/src/spikeinterface/sortingcomponents/tools.py index ac5d36982e..922abd2adb 100644 --- a/src/spikeinterface/sortingcomponents/tools.py +++ b/src/spikeinterface/sortingcomponents/tools.py @@ -541,7 +541,7 @@ def clean_templates( remove_empty=True, mean_sd_ratio_threshold=3.0, max_std_per_channel=None, - verbose=True, + verbose=False, ): """ Clean a Templates object by removing empty units and applying sparsity if provided. 
From 0b80955e6a2d15f3b784b47868bb8d3c796bba50 Mon Sep 17 00:00:00 2001 From: Pierre Yger Date: Tue, 9 Dec 2025 21:44:41 +0100 Subject: [PATCH 27/38] trying to get old behavior --- src/spikeinterface/sorters/internal/spyking_circus2.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/spikeinterface/sorters/internal/spyking_circus2.py b/src/spikeinterface/sorters/internal/spyking_circus2.py index e63f895733..56c858730f 100644 --- a/src/spikeinterface/sorters/internal/spyking_circus2.py +++ b/src/spikeinterface/sorters/internal/spyking_circus2.py @@ -39,7 +39,7 @@ class Spykingcircus2Sorter(ComponentsBasedSorter): "motion_correction": {"preset": "dredge_fast"}, "merging": {"max_distance_um": 50}, "clustering": {"method": "iterative-hdbscan", "method_kwargs": dict()}, - "cleaning": {"min_snr": 2.5, "max_jitter_ms": 0.2, "sparsify_threshold": 1, "mean_sd_ratio_threshold": 3}, + "cleaning": {"min_snr": 5, "max_jitter_ms": 0.2, "sparsify_threshold": 1, "mean_sd_ratio_threshold": 3}, "matching": {"method": "circus-omp", "method_kwargs": dict(), "pipeline_kwargs": dict()}, "apply_preprocessing": True, "apply_whitening": True, From a869f89fded7865d90a9ca92abf2b8fd1a024d0a Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Wed, 10 Dec 2025 14:35:05 +0100 Subject: [PATCH 28/38] more clean after clustering --- .../benchmark/benchmark_clustering.py | 4 +- .../benchmark/benchmark_matching.py | 4 +- src/spikeinterface/sorters/internal/lupin.py | 15 +- .../sorters/internal/spyking_circus2.py | 12 ++ .../sorters/internal/tridesclous2.py | 8 +- .../clustering/iterative_isosplit.py | 203 ++---------------- .../sortingcomponents/clustering/tools.py | 30 +++ 7 files changed, 75 insertions(+), 201 deletions(-) diff --git a/src/spikeinterface/benchmark/benchmark_clustering.py b/src/spikeinterface/benchmark/benchmark_clustering.py index 3885fe073c..31e840c6ee 100644 --- a/src/spikeinterface/benchmark/benchmark_clustering.py +++ 
b/src/spikeinterface/benchmark/benchmark_clustering.py @@ -29,9 +29,9 @@ def __init__(self, recording, gt_sorting, params, indices, peaks, exhaustive_gt= self.method_kwargs = params["method_kwargs"] self.result = {} - def run(self, **job_kwargs): + def run(self, verbose=True, **job_kwargs): labels, peak_labels = find_clusters_from_peaks( - self.recording, self.peaks, method=self.method, method_kwargs=self.method_kwargs, job_kwargs=job_kwargs + self.recording, self.peaks, method=self.method, method_kwargs=self.method_kwargs, verbose=verbose, job_kwargs=job_kwargs ) self.result["peak_labels"] = peak_labels diff --git a/src/spikeinterface/benchmark/benchmark_matching.py b/src/spikeinterface/benchmark/benchmark_matching.py index 8115870876..4b37518c28 100644 --- a/src/spikeinterface/benchmark/benchmark_matching.py +++ b/src/spikeinterface/benchmark/benchmark_matching.py @@ -26,9 +26,9 @@ def __init__(self, recording, gt_sorting, params): self.method_kwargs = params["method_kwargs"] self.result = {} - def run(self, **job_kwargs): + def run(self, verbose=True, **job_kwargs): spikes = find_spikes_from_templates( - self.recording, self.templates, method=self.method, method_kwargs=self.method_kwargs, job_kwargs=job_kwargs + self.recording, self.templates, method=self.method, method_kwargs=self.method_kwargs, verbose=verbose, job_kwargs=job_kwargs ) unit_ids = self.templates.unit_ids sorting = np.zeros(spikes.size, dtype=minimum_spike_dtype) diff --git a/src/spikeinterface/sorters/internal/lupin.py b/src/spikeinterface/sorters/internal/lupin.py index 2743798e85..b80baf59af 100644 --- a/src/spikeinterface/sorters/internal/lupin.py +++ b/src/spikeinterface/sorters/internal/lupin.py @@ -55,8 +55,9 @@ class LupinSorter(ComponentsBasedSorter): "ms_before": 1.0, "ms_after": 2.5, "template_sparsify_threshold": 1.5, - "template_min_snr": 2.5, + "template_min_snr_ptp": 4., "template_max_jitter_ms": 0.2, + "min_firing_rate": 0.1, "gather_mode": "memory", "job_kwargs": {}, "seed": 
None, @@ -82,7 +83,9 @@ class LupinSorter(ComponentsBasedSorter): "ms_before": "Milliseconds before the spike peak for template matching", "ms_after": "Milliseconds after the spike peak for template matching", "template_sparsify_threshold": "Threshold to sparsify templates before template matching", - "template_min_snr": "Threshold to remove templates before template matching", + "template_min_snr_ptp": "Threshold to remove templates before template matching", + "template_max_jitter_ms": "Threshold on jitters to remove templates before template matching", + "min_firing_rate": "To remove small cluster in size before template matching", "gather_mode": "How to accumalte spike in matching : memory/npy", "job_kwargs": "The famous and fabulous job_kwargs", "seed": "Seed for random number", @@ -234,9 +237,11 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): clustering_kwargs["split"]["recursive_depth"] = params["clustering_recursive_depth"] clustering_kwargs["split"]["method_kwargs"]["n_pca_features"] = params["n_pca_features"] clustering_kwargs["clean_templates"]["sparsify_threshold"] = params["template_sparsify_threshold"] - clustering_kwargs["clean_templates"]["template_min_snr"] = params["template_min_snr"] - clustering_kwargs["clean_templates"]["template_max_jitter_ms"] = params["template_max_jitter_ms"] + clustering_kwargs["clean_templates"]["min_snr"] = params["template_min_snr_ptp"] + clustering_kwargs["clean_templates"]["max_jitter_ms"] = params["template_max_jitter_ms"] clustering_kwargs["noise_levels"] = noise_levels + clustering_kwargs["clean"]["min_firing_rate"] = params["min_firing_rate"] + clustering_kwargs["clean"]["subsampling_factor"] = all_peaks.size / peaks.size if params["debug"]: clustering_kwargs["debug_folder"] = sorter_output_folder @@ -297,7 +302,7 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): templates, sparsify_threshold=params["template_sparsify_threshold"], noise_levels=noise_levels, - 
min_snr=params["template_min_snr"], + min_snr=params["template_min_snr_ptp"], max_jitter_ms=params["template_max_jitter_ms"], remove_empty=True, ) diff --git a/src/spikeinterface/sorters/internal/spyking_circus2.py b/src/spikeinterface/sorters/internal/spyking_circus2.py index 4f3c9cc7d8..ca693e1a9d 100644 --- a/src/spikeinterface/sorters/internal/spyking_circus2.py +++ b/src/spikeinterface/sorters/internal/spyking_circus2.py @@ -40,6 +40,7 @@ class Spykingcircus2Sorter(ComponentsBasedSorter): "merging": {"max_distance_um": 50}, "clustering": {"method": "iterative-hdbscan", "method_kwargs": dict()}, "cleaning": {"min_snr": 5, "max_jitter_ms": 0.2, "sparsify_threshold": 1, "mean_sd_ratio_threshold": 3}, + "min_firing_rate" : 0.1, "matching": {"method": "circus-omp", "method_kwargs": dict(), "pipeline_kwargs": dict()}, "apply_preprocessing": True, "apply_whitening": True, @@ -103,6 +104,7 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): from spikeinterface.sortingcomponents.peak_detection import detect_peaks from spikeinterface.sortingcomponents.peak_selection import select_peaks from spikeinterface.sortingcomponents.clustering import find_clusters_from_peaks + from spikeinterface.sortingcomponents.clustering.tools import remove_small_cluster from spikeinterface.sortingcomponents.matching import find_spikes_from_templates from spikeinterface.sortingcomponents.tools import check_probe_for_drift_correction from spikeinterface.sortingcomponents.tools import clean_templates @@ -388,9 +390,19 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): del more_outs + before_clean_ids = templates.unit_ids.copy() cleaning_kwargs["max_std_per_channel"] = max_std_per_channel cleaning_kwargs["verbose"] = verbose templates = clean_templates(templates, noise_levels=noise_levels, **cleaning_kwargs) + remove_peak_mask = ~np.isin(peak_labels, templates.unit_ids) + peak_labels[remove_peak_mask] = -1 + + if params["min_firing_rate"] is not None: + 
peak_labels, to_keep = remove_small_cluster(recording_w, selected_peaks, peak_labels, + min_firing_rate=params["min_firing_rate"], + subsampling_factor=peaks.size / selected_peaks.size, + verbose=verbose) + templates = templates.select_units(to_keep) if verbose: print("Kept %d clean clusters" % len(templates.unit_ids)) diff --git a/src/spikeinterface/sorters/internal/tridesclous2.py b/src/spikeinterface/sorters/internal/tridesclous2.py index a52fc92e86..6e2bd7eb0e 100644 --- a/src/spikeinterface/sorters/internal/tridesclous2.py +++ b/src/spikeinterface/sorters/internal/tridesclous2.py @@ -51,12 +51,13 @@ class Tridesclous2Sorter(ComponentsBasedSorter): "clustering": { "recursive_depth": 3, }, + "min_firing_rate": 0.1, "templates": { "ms_before": 2.0, "ms_after": 3.0, "max_spikes_per_unit": 400, "sparsity_threshold": 1.5, - "min_snr": 2.5, + "min_snr": 3.5, "radius_um": 100.0, "max_jitter_ms": 0.2, }, @@ -195,6 +196,9 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): if params["debug"]: clustering_kwargs["debug_folder"] = sorter_output_folder clustering_kwargs["noise_levels"] = noise_levels + clustering_kwargs["clean"]["min_firing_rate"] = params["min_firing_rate"] + clustering_kwargs["clean"]["subsampling_factor"] = all_peaks.size / peaks.size + # if clustering_kwargs["clustering"]["clusterer"] == "isosplit6": # have_sisosplit6 = importlib.util.find_spec("isosplit6") is not None @@ -263,7 +267,7 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): is_in_uV=False, ) - # this spasify more + # this clean and spasify more templates = clean_templates( templates, sparsify_threshold=params["templates"]["sparsity_threshold"], diff --git a/src/spikeinterface/sortingcomponents/clustering/iterative_isosplit.py b/src/spikeinterface/sortingcomponents/clustering/iterative_isosplit.py index 7598a6a095..b4af99bec0 100644 --- a/src/spikeinterface/sortingcomponents/clustering/iterative_isosplit.py +++ 
b/src/spikeinterface/sortingcomponents/clustering/iterative_isosplit.py @@ -10,7 +10,7 @@ merge_peak_labels_from_templates, merge_peak_labels_from_features, ) -from spikeinterface.sortingcomponents.clustering.tools import get_templates_from_peaks_and_svd +from spikeinterface.sortingcomponents.clustering.tools import get_templates_from_peaks_and_svd, remove_small_cluster from spikeinterface.sortingcomponents.tools import clean_templates from spikeinterface.sortingcomponents.waveforms.peak_svd import extract_peaks_svd from spikeinterface.core.recording_tools import get_noise_levels @@ -76,7 +76,8 @@ class IterativeISOSPLITClustering: "merge_from_features": None, # "merge_from_features": {"merge_radius_um": 60.0}, "clean": { - "minimum_cluster_size": 10, + "min_firing_rate": 0.1, + "subsampling_factor": None, }, "debug_folder": None, "verbose": True, @@ -219,8 +220,8 @@ def main_function(cls, recording, peaks, params, job_kwargs=dict()): ## Pre clean using templates (jitter, sparsify_threshold) templates = dense_templates.to_sparse(template_sparse_mask) cleaning_kwargs = params["clean_templates"].copy() - # cleaning_kwargs["verbose"] = verbose - cleaning_kwargs["verbose"] = True + cleaning_kwargs["verbose"] = verbose + # cleaning_kwargs["verbose"] = True # cleaning_kwargs["max_std_per_channel"] = max_std_per_channel if params["noise_levels"] is not None: noise_levels = params["noise_levels"] @@ -309,22 +310,15 @@ def main_function(cls, recording, peaks, params, job_kwargs=dict()): sparsity = ChannelSparsity(template_sparse_mask, unit_ids, recording.channel_ids) templates = dense_templates.to_sparse(sparsity) - # sparse_wfs = np.load(features_folder / "sparse_wfs.npy", mmap_mode="r") - - # new_peaks = peaks.copy() - # new_peaks["sample_index"] -= peak_shifts - # clean very small cluster before peeler - post_clean_label = post_merge_label2.copy() - minimum_cluster_size = params["clean"]["minimum_cluster_size"] - labels_set, count = np.unique(post_clean_label, 
return_counts=True) - to_remove = labels_set[count < minimum_cluster_size] - mask = np.isin(post_clean_label, to_remove) - post_clean_label[mask] = -1 - final_peak_labels = post_clean_label - labels_set = np.unique(final_peak_labels) - labels_set = labels_set[labels_set >= 0] - templates = templates.select_units(labels_set) + if params["clean"]["subsampling_factor"] is not None and params["clean"]["min_firing_rate"] is not None: + final_peak_labels, to_keep = remove_small_cluster(recording, peaks, post_merge_label2, + min_firing_rate=params["clean"]["min_firing_rate"], + subsampling_factor=params["clean"]["subsampling_factor"], + verbose=verbose, + ) + templates = templates.select_units(to_keep) + labels_set = templates.unit_ids more_outs = dict( @@ -332,174 +326,3 @@ def main_function(cls, recording, peaks, params, job_kwargs=dict()): ) return labels_set, final_peak_labels, more_outs - # _default_params = { - # "clean": { - # "minimum_cluster_size": 10, - # }, - # } - - # @classmethod - # def main_function(cls, recording, peaks, params, job_kwargs=dict()): - - # split_radius_um = params["split"].pop("split_radius_um", 40) - # peaks_svd = params["peaks_svd"] - # motion = peaks_svd["motion"] - # ms_before = peaks_svd.get("ms_before", 0.5) - # ms_after = peaks_svd.get("ms_after", 1.5) - # verbose = params.get("verbose", True) - # split = params["split"] - # seed = params["seed"] - # job_kwargs = params.get("job_kwargs", dict()) - # debug_folder = params.get("debug_folder", None) - - # if debug_folder is not None: - # debug_folder = Path(debug_folder).absolute() - # debug_folder.mkdir(exist_ok=True) - # peaks_svd.update(folder=debug_folder / "features") - - # motion_aware = motion is not None - # peaks_svd.update(motion_aware=motion_aware) - - # if seed is not None: - # peaks_svd.update(seed=seed) - # split["method_kwargs"].update(seed=seed) - - # outs = extract_peaks_svd( - # recording, - # peaks, - # **peaks_svd, - # **job_kwargs, - # ) - - # if motion_aware: - # # 
also return peaks with new channel index - # peaks_svd, sparse_mask, svd_model, moved_peaks = outs - # peaks = moved_peaks - # else: - # peaks_svd, sparse_mask, svd_model = outs - - # if debug_folder is not None: - # np.save(debug_folder / "sparse_mask.npy", sparse_mask) - # np.save(debug_folder / "peaks.npy", peaks) - - # split["method_kwargs"].update(waveforms_sparse_mask = sparse_mask) - # neighbours_mask = get_channel_distances(recording) <= split_radius_um - # split["method_kwargs"].update(neighbours_mask = neighbours_mask) - - # if debug_folder is not None: - # split.update(debug_folder = debug_folder / "split") - - # peak_labels = split_clusters( - # peaks["channel_index"], - # recording, - # {"peaks": peaks, "sparse_tsvd": peaks_svd}, - # method="local_feature_clustering", - # **split, - # **job_kwargs, - # ) - - # templates, new_sparse_mask = get_templates_from_peaks_and_svd( - # recording, - # peaks, - # peak_labels, - # ms_before, - # ms_after, - # svd_model, - # peaks_svd, - # sparse_mask, - # operator="average", - # ) - - # labels = templates.unit_ids - - # if verbose: - # print("Kept %d raw clusters" % len(labels)) - - # if params["merge_from_features"] is not None: - - # merge_features_kwargs = params["merge_from_features"].copy() - # merge_radius_um = merge_features_kwargs.pop("merge_radius_um") - - # peak_labels, merge_template_array, new_sparse_mask, new_unit_ids = merge_peak_labels_from_features( - # peaks, - # peak_labels, - # templates.unit_ids, - # templates.templates_array, - # sparse_mask, - # recording, - # {"peaks": peaks, "sparse_tsvd": peaks_svd}, - # radius_um=merge_radius_um, - # method="project_distribution", - # method_kwargs=dict( - # feature_name="sparse_tsvd", - # waveforms_sparse_mask=sparse_mask, - # **merge_features_kwargs - # ), - # **job_kwargs, - # ) - - # templates = Templates( - # templates_array=merge_template_array, - # sampling_frequency=recording.sampling_frequency, - # nbefore=templates.nbefore, - # 
sparsity_mask=None, - # channel_ids=recording.channel_ids, - # unit_ids=new_unit_ids, - # probe=recording.get_probe(), - # is_in_uV=False, - # ) - - # if params["merge_from_templates"] is not None: - # peak_labels, merge_template_array, new_sparse_mask, new_unit_ids = merge_peak_labels_from_templates( - # peaks, - # peak_labels, - # templates.unit_ids, - # templates.templates_array, - # new_sparse_mask, - # **params["merge_from_templates"], - # ) - - # templates = Templates( - # templates_array=merge_template_array, - # sampling_frequency=recording.sampling_frequency, - # nbefore=templates.nbefore, - # sparsity_mask=None, - # channel_ids=recording.channel_ids, - # unit_ids=new_unit_ids, - # probe=recording.get_probe(), - # is_in_uV=False, - # ) - - # labels = templates.unit_ids - - # if debug_folder is not None: - # templates.to_zarr(folder_path=debug_folder / "dense_templates") - - # if verbose: - # print("Kept %d non-duplicated clusters" % len(labels)) - - # # sparsity = ChannelSparsity(template_sparse_mask, unit_ids, recording.channel_ids) - # # templates = dense_templates.to_sparse(sparsity) - - # # # sparse_wfs = np.load(features_folder / "sparse_wfs.npy", mmap_mode="r") - - # # # new_peaks = peaks.copy() - # # # new_peaks["sample_index"] -= peak_shifts - - # # # clean very small cluster before peeler - # # post_clean_label = post_merge_label2.copy() - # # minimum_cluster_size = params["clean"]["minimum_cluster_size"] - # # labels_set, count = np.unique(post_clean_label, return_counts=True) - # # to_remove = labels_set[count < minimum_cluster_size] - # # mask = np.isin(post_clean_label, to_remove) - # # post_clean_label[mask] = -1 - # # final_peak_labels = post_clean_label - # # labels_set = np.unique(final_peak_labels) - # # labels_set = labels_set[labels_set >= 0] - # # templates = templates.select_units(labels_set) - # # labels_set = templates.unit_ids - - # more_outs = dict( - # templates=templates, - # ) - # return labels, peak_labels, more_outs diff 
--git a/src/spikeinterface/sortingcomponents/clustering/tools.py b/src/spikeinterface/sortingcomponents/clustering/tools.py index f6a6029d0d..9fe46d4936 100644 --- a/src/spikeinterface/sortingcomponents/clustering/tools.py +++ b/src/spikeinterface/sortingcomponents/clustering/tools.py @@ -326,3 +326,33 @@ def get_templates_from_peaks_and_svd( return dense_templates, final_sparsity_mask, max_std_per_channel else: return dense_templates, final_sparsity_mask + + +def remove_small_cluster(recording, peaks, peak_labels, min_firing_rate=0.1, subsampling_factor=None, verbose=False): + """ + Remove clusters too small in size (spike count) given a min firing rate and a subsampling factor. + """ + + if subsampling_factor is None: + if verbose: + print("remove_small_cluster(): subsampling_factor is not set, assuming 1") + subsampling_factor = 1 + + min_spike_count = int(recording.get_total_duration() * min_firing_rate / subsampling_factor) + + peak_labels = peak_labels.copy() + labels_set, count = np.unique(peak_labels, return_counts=True) + cluster_mask = count < min_spike_count + to_remove = labels_set[cluster_mask] + to_keep = labels_set[~cluster_mask] + peak_mask = np.isin(peak_labels, to_remove) + peak_labels[peak_mask] = -1 + + to_keep = to_keep[to_keep >= 0] + + if verbose: + print(f"remove_small_cluster: kept {to_keep.size} removed {to_remove.size} (min_spike_count {min_spike_count})") + + return peak_labels, to_keep + + From fb20f7fb53ab3689d16c0ae285fa3b2a434fd842 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 10 Dec 2025 13:35:41 +0000 Subject: [PATCH 29/38] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- .../benchmark/benchmark_clustering.py | 7 ++++++- src/spikeinterface/benchmark/benchmark_matching.py | 7 ++++++- src/spikeinterface/sorters/internal/lupin.py | 2 +- .../sorters/internal/spyking_circus2.py | 14 +++++++++----- 
.../sorters/internal/tridesclous2.py | 1 - .../clustering/iterative_isosplit.py | 14 ++++++++------ .../sortingcomponents/clustering/tools.py | 12 ++++++------ 7 files changed, 36 insertions(+), 21 deletions(-) diff --git a/src/spikeinterface/benchmark/benchmark_clustering.py b/src/spikeinterface/benchmark/benchmark_clustering.py index 31e840c6ee..6cf6cbe7a3 100644 --- a/src/spikeinterface/benchmark/benchmark_clustering.py +++ b/src/spikeinterface/benchmark/benchmark_clustering.py @@ -31,7 +31,12 @@ def __init__(self, recording, gt_sorting, params, indices, peaks, exhaustive_gt= def run(self, verbose=True, **job_kwargs): labels, peak_labels = find_clusters_from_peaks( - self.recording, self.peaks, method=self.method, method_kwargs=self.method_kwargs, verbose=verbose, job_kwargs=job_kwargs + self.recording, + self.peaks, + method=self.method, + method_kwargs=self.method_kwargs, + verbose=verbose, + job_kwargs=job_kwargs, ) self.result["peak_labels"] = peak_labels diff --git a/src/spikeinterface/benchmark/benchmark_matching.py b/src/spikeinterface/benchmark/benchmark_matching.py index 4b37518c28..02e2146b5a 100644 --- a/src/spikeinterface/benchmark/benchmark_matching.py +++ b/src/spikeinterface/benchmark/benchmark_matching.py @@ -28,7 +28,12 @@ def __init__(self, recording, gt_sorting, params): def run(self, verbose=True, **job_kwargs): spikes = find_spikes_from_templates( - self.recording, self.templates, method=self.method, method_kwargs=self.method_kwargs, verbose=verbose, job_kwargs=job_kwargs + self.recording, + self.templates, + method=self.method, + method_kwargs=self.method_kwargs, + verbose=verbose, + job_kwargs=job_kwargs, ) unit_ids = self.templates.unit_ids sorting = np.zeros(spikes.size, dtype=minimum_spike_dtype) diff --git a/src/spikeinterface/sorters/internal/lupin.py b/src/spikeinterface/sorters/internal/lupin.py index b80baf59af..26fd952f6f 100644 --- a/src/spikeinterface/sorters/internal/lupin.py +++ b/src/spikeinterface/sorters/internal/lupin.py 
@@ -55,7 +55,7 @@ class LupinSorter(ComponentsBasedSorter): "ms_before": 1.0, "ms_after": 2.5, "template_sparsify_threshold": 1.5, - "template_min_snr_ptp": 4., + "template_min_snr_ptp": 4.0, "template_max_jitter_ms": 0.2, "min_firing_rate": 0.1, "gather_mode": "memory", diff --git a/src/spikeinterface/sorters/internal/spyking_circus2.py b/src/spikeinterface/sorters/internal/spyking_circus2.py index ca693e1a9d..49418a1706 100644 --- a/src/spikeinterface/sorters/internal/spyking_circus2.py +++ b/src/spikeinterface/sorters/internal/spyking_circus2.py @@ -40,7 +40,7 @@ class Spykingcircus2Sorter(ComponentsBasedSorter): "merging": {"max_distance_um": 50}, "clustering": {"method": "iterative-hdbscan", "method_kwargs": dict()}, "cleaning": {"min_snr": 5, "max_jitter_ms": 0.2, "sparsify_threshold": 1, "mean_sd_ratio_threshold": 3}, - "min_firing_rate" : 0.1, + "min_firing_rate": 0.1, "matching": {"method": "circus-omp", "method_kwargs": dict(), "pipeline_kwargs": dict()}, "apply_preprocessing": True, "apply_whitening": True, @@ -398,10 +398,14 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): peak_labels[remove_peak_mask] = -1 if params["min_firing_rate"] is not None: - peak_labels, to_keep = remove_small_cluster(recording_w, selected_peaks, peak_labels, - min_firing_rate=params["min_firing_rate"], - subsampling_factor=peaks.size / selected_peaks.size, - verbose=verbose) + peak_labels, to_keep = remove_small_cluster( + recording_w, + selected_peaks, + peak_labels, + min_firing_rate=params["min_firing_rate"], + subsampling_factor=peaks.size / selected_peaks.size, + verbose=verbose, + ) templates = templates.select_units(to_keep) if verbose: diff --git a/src/spikeinterface/sorters/internal/tridesclous2.py b/src/spikeinterface/sorters/internal/tridesclous2.py index 6e2bd7eb0e..68b27c6b0b 100644 --- a/src/spikeinterface/sorters/internal/tridesclous2.py +++ b/src/spikeinterface/sorters/internal/tridesclous2.py @@ -199,7 +199,6 @@ def _run_from_folder(cls, 
sorter_output_folder, params, verbose): clustering_kwargs["clean"]["min_firing_rate"] = params["min_firing_rate"] clustering_kwargs["clean"]["subsampling_factor"] = all_peaks.size / peaks.size - # if clustering_kwargs["clustering"]["clusterer"] == "isosplit6": # have_sisosplit6 = importlib.util.find_spec("isosplit6") is not None # if not have_sisosplit6: diff --git a/src/spikeinterface/sortingcomponents/clustering/iterative_isosplit.py b/src/spikeinterface/sortingcomponents/clustering/iterative_isosplit.py index b4af99bec0..cf821e8c67 100644 --- a/src/spikeinterface/sortingcomponents/clustering/iterative_isosplit.py +++ b/src/spikeinterface/sortingcomponents/clustering/iterative_isosplit.py @@ -312,11 +312,14 @@ def main_function(cls, recording, peaks, params, job_kwargs=dict()): # clean very small cluster before peeler if params["clean"]["subsampling_factor"] is not None and params["clean"]["min_firing_rate"] is not None: - final_peak_labels, to_keep = remove_small_cluster(recording, peaks, post_merge_label2, - min_firing_rate=params["clean"]["min_firing_rate"], - subsampling_factor=params["clean"]["subsampling_factor"], - verbose=verbose, - ) + final_peak_labels, to_keep = remove_small_cluster( + recording, + peaks, + post_merge_label2, + min_firing_rate=params["clean"]["min_firing_rate"], + subsampling_factor=params["clean"]["subsampling_factor"], + verbose=verbose, + ) templates = templates.select_units(to_keep) labels_set = templates.unit_ids @@ -325,4 +328,3 @@ def main_function(cls, recording, peaks, params, job_kwargs=dict()): templates=templates, ) return labels_set, final_peak_labels, more_outs - diff --git a/src/spikeinterface/sortingcomponents/clustering/tools.py b/src/spikeinterface/sortingcomponents/clustering/tools.py index 9fe46d4936..406aee1bd4 100644 --- a/src/spikeinterface/sortingcomponents/clustering/tools.py +++ b/src/spikeinterface/sortingcomponents/clustering/tools.py @@ -337,7 +337,7 @@ def remove_small_cluster(recording, peaks, 
peak_labels, min_firing_rate=0.1, sub if verbose: print("remove_small_cluster(): subsampling_factor is not set, assuming 1") subsampling_factor = 1 - + min_spike_count = int(recording.get_total_duration() * min_firing_rate / subsampling_factor) peak_labels = peak_labels.copy() @@ -349,10 +349,10 @@ def remove_small_cluster(recording, peaks, peak_labels, min_firing_rate=0.1, sub peak_labels[peak_mask] = -1 to_keep = to_keep[to_keep >= 0] - - if verbose: - print(f"remove_small_cluster: kept {to_keep.size} removed {to_remove.size} (min_spike_count {min_spike_count})") - - return peak_labels, to_keep + if verbose: + print( + f"remove_small_cluster: kept {to_keep.size} removed {to_remove.size} (min_spike_count {min_spike_count})" + ) + return peak_labels, to_keep From 8b6f3d9cf50a1e388b2a659859213248ae1e9e86 Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Wed, 10 Dec 2025 15:03:13 +0100 Subject: [PATCH 30/38] oups --- .../sortingcomponents/clustering/iterative_isosplit.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/spikeinterface/sortingcomponents/clustering/iterative_isosplit.py b/src/spikeinterface/sortingcomponents/clustering/iterative_isosplit.py index b4af99bec0..5fa5208c95 100644 --- a/src/spikeinterface/sortingcomponents/clustering/iterative_isosplit.py +++ b/src/spikeinterface/sortingcomponents/clustering/iterative_isosplit.py @@ -318,6 +318,8 @@ def main_function(cls, recording, peaks, params, job_kwargs=dict()): verbose=verbose, ) templates = templates.select_units(to_keep) + else: + final_peak_labels = post_merge_label2 labels_set = templates.unit_ids From ed9bf21dab0b54a0c673edaf5187114700614b88 Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Wed, 10 Dec 2025 15:50:29 +0100 Subject: [PATCH 31/38] Flatten tridesclous2 params --- .../internal/tests/test_tridesclous2.py | 2 +- .../sorters/internal/tridesclous2.py | 264 +++++++++++++----- 2 files changed, 192 insertions(+), 74 deletions(-) diff --git 
a/src/spikeinterface/sorters/internal/tests/test_tridesclous2.py b/src/spikeinterface/sorters/internal/tests/test_tridesclous2.py index 1f1f109d28..01284d7ee6 100644 --- a/src/spikeinterface/sorters/internal/tests/test_tridesclous2.py +++ b/src/spikeinterface/sorters/internal/tests/test_tridesclous2.py @@ -17,7 +17,7 @@ def test_with_numpy_gather(self): output_folder = self.cache_folder / sorter_name sorter_params = self.SorterClass.default_params() - sorter_params["matching"]["gather_mode"] = "npy" + sorter_params["gather_mode"] = "npy" sorting = run_sorter( sorter_name, diff --git a/src/spikeinterface/sorters/internal/tridesclous2.py b/src/spikeinterface/sorters/internal/tridesclous2.py index 68b27c6b0b..f572bcd565 100644 --- a/src/spikeinterface/sorters/internal/tridesclous2.py +++ b/src/spikeinterface/sorters/internal/tridesclous2.py @@ -29,64 +29,124 @@ class Tridesclous2Sorter(ComponentsBasedSorter): sorter_name = "tridesclous2" + # _default_params = { + # "apply_preprocessing": True, + # "apply_motion_correction": False, + # "motion_correction": {"preset": "dredge_fast"}, + # "cache_preprocessing_mode": "auto", + # "waveforms": { + # "ms_before": 0.5, + # "ms_after": 1.5, + # "radius_um": 120.0, + # }, + # "filtering": { + # "freq_min": 150.0, + # "freq_max": 6000.0, + # "ftype": "bessel", + # "filter_order": 2, + # }, + # "detection": {"peak_sign": "neg", "detect_threshold": 5, "exclude_sweep_ms": 1.5, "radius_um": 150.0}, + # "selection": {"n_peaks_per_channel": 5000, "min_n_peaks": 20000}, + # "svd": {"n_components": 5}, + # "clustering": { + # "recursive_depth": 3, + # }, + # "min_firing_rate": 0.1, + # "templates": { + # "ms_before": 2.0, + # "ms_after": 3.0, + # "max_spikes_per_unit": 400, + # "sparsity_threshold": 1.5, + # "min_snr": 3.5, + # "radius_um": 100.0, + # "max_jitter_ms": 0.2, + # }, + # "matching": {"method": "tdc-peeler", "method_kwargs": {}, "gather_mode": "memory"}, + # "job_kwargs": {}, + # "save_array": True, + # "debug": False, + # } 
+ + # _params_description = { + # "apply_preprocessing": "Apply internal preprocessing or not", + # "cache_preprocessing": "A dict contaning how to cache the preprocessed recording. mode='memory' | 'folder | 'zarr' ", + # "waveforms": "A dictonary containing waveforms params: ms_before, ms_after, radius_um", + # "filtering": "A dictonary containing filtering params: freq_min, freq_max", + # "detection": "A dictonary containing detection params: peak_sign, detect_threshold, exclude_sweep_ms, radius_um", + # "selection": "A dictonary containing selection params: n_peaks_per_channel, min_n_peaks", + # "svd": "A dictonary containing svd params: n_components", + # "clustering": "A dictonary containing clustering params: split_radius_um, merge_radius_um", + # "templates": "A dictonary containing waveforms params for peeler: ms_before, ms_after", + # "matching": "A dictonary containing matching params for matching: peak_shift_ms, radius_um", + # "job_kwargs": "A dictionary containing job kwargs", + # "save_array": "Save or not intermediate arrays", + # } + + _default_params = { "apply_preprocessing": True, "apply_motion_correction": False, - "motion_correction": {"preset": "dredge_fast"}, + "motion_correction_preset": "dredge_fast", + "clustering_ms_before": 0.5, + "clustering_ms_after": 1.5, + "detection_radius_um": 150.0, + "features_radius_um": 75.0, + "template_radius_um": 100.0, + "freq_min": 150.0, + "freq_max": 6000.0, "cache_preprocessing_mode": "auto", - "waveforms": { - "ms_before": 0.5, - "ms_after": 1.5, - "radius_um": 120.0, - }, - "filtering": { - "freq_min": 150.0, - "freq_max": 6000.0, - "ftype": "bessel", - "filter_order": 2, - }, - "detection": {"peak_sign": "neg", "detect_threshold": 5, "exclude_sweep_ms": 1.5, "radius_um": 150.0}, - "selection": {"n_peaks_per_channel": 5000, "min_n_peaks": 20000}, - "svd": {"n_components": 5}, - "clustering": { - "recursive_depth": 3, - }, + "peak_sign": "neg", + "detect_threshold": 5., + "n_peaks_per_channel": 5000, + 
"n_svd_components_per_channel": 5, + "n_pca_features": 6, + "clustering_recursive_depth": 3, + "ms_before": 2.0, + "ms_after": 3.0, + "template_sparsify_threshold": 1.5, + "template_min_snr_ptp": 3.5, + "template_max_jitter_ms": 0.2, "min_firing_rate": 0.1, - "templates": { - "ms_before": 2.0, - "ms_after": 3.0, - "max_spikes_per_unit": 400, - "sparsity_threshold": 1.5, - "min_snr": 3.5, - "radius_um": 100.0, - "max_jitter_ms": 0.2, - }, - "matching": {"method": "tdc-peeler", "method_kwargs": {}, "gather_mode": "memory"}, + "gather_mode": "memory", "job_kwargs": {}, + "seed": None, "save_array": True, "debug": False, } _params_description = { "apply_preprocessing": "Apply internal preprocessing or not", - "cache_preprocessing": "A dict contaning how to cache the preprocessed recording. mode='memory' | 'folder | 'zarr' ", - "waveforms": "A dictonary containing waveforms params: ms_before, ms_after, radius_um", - "filtering": "A dictonary containing filtering params: freq_min, freq_max", - "detection": "A dictonary containing detection params: peak_sign, detect_threshold, exclude_sweep_ms, radius_um", - "selection": "A dictonary containing selection params: n_peaks_per_channel, min_n_peaks", - "svd": "A dictonary containing svd params: n_components", - "clustering": "A dictonary containing clustering params: split_radius_um, merge_radius_um", - "templates": "A dictonary containing waveforms params for peeler: ms_before, ms_after", - "matching": "A dictonary containing matching params for matching: peak_shift_ms, radius_um", - "job_kwargs": "A dictionary containing job kwargs", - "save_array": "Save or not intermediate arrays", + "apply_motion_correction": "Apply motion correction or not", + "motion_correction_preset": "Motion correction preset", + "clustering_ms_before": "Milliseconds before the spike peak for clustering", + "clustering_ms_after": "Milliseconds after the spike peak for clustering", + "radius_um": "Radius for sparsity", + "freq_min": "Low frequency", 
+ "freq_max": "High frequency", + "peak_sign": "Sign of peaks neg/pos/both", + "detect_threshold": "Threshold for peak detection", + "n_peaks_per_channel": "Number of spikes per channel for clustering", + "n_svd_components_per_channel": "Number of SVD components per channel for clustering", + "n_pca_features": "Secondary PCA feature reduction before local isosplit", + "clustering_recursive_depth": "Clustering recursion depth", + "ms_before": "Milliseconds before the spike peak for template matching", + "ms_after": "Milliseconds after the spike peak for template matching", + "template_sparsify_threshold": "Threshold to sparsify templates before template matching", + "template_min_snr_ptp": "Threshold to remove templates before template matching", + "template_max_jitter_ms": "Threshold on jitters to remove templates before template matching", + "min_firing_rate": "Minimum firing rate used to remove small clusters before template matching", + "gather_mode": "How to accumulate spikes in matching: memory/npy", + "job_kwargs": "The famous and fabulous job_kwargs", + "seed": "Seed for random number", + "save_array": "Save or not intermediate arrays in the folder", + "debug": "Save debug files", } handle_multi_segment = True @classmethod def get_sorter_version(cls): - return "2025.11" + return "2025.12" @classmethod def _run_from_folder(cls, sorter_output_folder, params, verbose): @@ -103,6 +163,8 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): job_kwargs = fix_job_kwargs(job_kwargs) job_kwargs["progress_bar"] = verbose + seed = params["seed"] + recording_raw = cls.load_recording_from_folder(sorter_output_folder.parent, with_warnings=False) num_chans = recording_raw.get_num_channels() @@ -131,7 +193,16 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): if verbose: print("Done correct_motion()") - recording = bandpass_filter(recording_raw, 
**params["filtering"], margin_ms=20.0, dtype="float32") + recording = bandpass_filter( + recording_raw, + freq_min=params["freq_min"], + freq_max=params["freq_max"], + ftype="bessel", + filter_order=2, + margin_ms=20.0, + dtype="float32", + ) if apply_cmr: recording = common_reference(recording) @@ -170,8 +241,15 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): cache_info = None # detection - detection_params = params["detection"].copy() - detection_params["noise_levels"] = noise_levels + # detection_params = params["detection"].copy() + # detection_params["noise_levels"] = noise_levels + detection_params = dict( + peak_sign=params["peak_sign"], + detect_threshold=params["detect_threshold"], + exclude_sweep_ms=1.5, + radius_um=params["detection_radius_um"], + ) + all_peaks = detect_peaks( recording, method="locally_exclusive", method_kwargs=detection_params, job_kwargs=job_kwargs ) @@ -180,24 +258,25 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): print(f"detect_peaks(): {len(all_peaks)} peaks found") # selection - selection_params = params["selection"].copy() - n_peaks = params["selection"]["n_peaks_per_channel"] * num_chans - n_peaks = max(selection_params["min_n_peaks"], n_peaks) + # selection_params = params["selection"].copy() + # n_peaks = params["selection"]["n_peaks_per_channel"] * num_chans + # n_peaks = max(selection_params["min_n_peaks"], n_peaks) + n_peaks = max(params["n_peaks_per_channel"] * num_chans, 20_000) peaks = select_peaks(all_peaks, method="uniform", n_peaks=n_peaks) if verbose: print(f"select_peaks(): {len(peaks)} peaks kept for clustering") # routing clustering params into the big IterativeISOSPLITClustering params tree - clustering_kwargs = deepcopy(clustering_methods["iterative-isosplit"]._default_params) - clustering_kwargs["peaks_svd"].update(params["waveforms"]) - clustering_kwargs["peaks_svd"].update(params["svd"]) - clustering_kwargs["split"].update(params["clustering"]) - if params["debug"]: 
- clustering_kwargs["debug_folder"] = sorter_output_folder - clustering_kwargs["noise_levels"] = noise_levels - clustering_kwargs["clean"]["min_firing_rate"] = params["min_firing_rate"] - clustering_kwargs["clean"]["subsampling_factor"] = all_peaks.size / peaks.size + # clustering_kwargs = deepcopy(clustering_methods["iterative-isosplit"]._default_params) + # clustering_kwargs["peaks_svd"].update(params["waveforms"]) + # clustering_kwargs["peaks_svd"].update(params["svd"]) + # clustering_kwargs["split"].update(params["clustering"]) + # if params["debug"]: + # clustering_kwargs["debug_folder"] = sorter_output_folder + # clustering_kwargs["noise_levels"] = noise_levels + # clustering_kwargs["clean"]["min_firing_rate"] = params["min_firing_rate"] + # clustering_kwargs["clean"]["subsampling_factor"] = all_peaks.size / peaks.size # if clustering_kwargs["clustering"]["clusterer"] == "isosplit6": # have_sisosplit6 = importlib.util.find_spec("isosplit6") is not None @@ -206,6 +285,24 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): # "You want to run tridesclous2 with the isosplit6 (the C++) implementation, but this is not installed, please `pip install isosplit6`" # ) + # Clustering + clustering_kwargs = deepcopy(clustering_methods["iterative-isosplit"]._default_params) + clustering_kwargs["peaks_svd"]["ms_before"] = params["clustering_ms_before"] + clustering_kwargs["peaks_svd"]["ms_after"] = params["clustering_ms_after"] + clustering_kwargs["peaks_svd"]["radius_um"] = params["features_radius_um"] + clustering_kwargs["peaks_svd"]["n_components"] = params["n_svd_components_per_channel"] + clustering_kwargs["split"]["recursive_depth"] = params["clustering_recursive_depth"] + clustering_kwargs["split"]["method_kwargs"]["n_pca_features"] = params["n_pca_features"] + clustering_kwargs["clean_templates"]["sparsify_threshold"] = params["template_sparsify_threshold"] + clustering_kwargs["clean_templates"]["min_snr"] = params["template_min_snr_ptp"] + 
clustering_kwargs["clean_templates"]["max_jitter_ms"] = params["template_max_jitter_ms"] + clustering_kwargs["noise_levels"] = noise_levels + clustering_kwargs["clean"]["min_firing_rate"] = params["min_firing_rate"] + clustering_kwargs["clean"]["subsampling_factor"] = all_peaks.size / peaks.size + + if params["debug"]: + clustering_kwargs["debug_folder"] = sorter_output_folder + unit_ids, clustering_label, more_outs = find_clusters_from_peaks( recording, peaks, @@ -231,18 +328,28 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): recording_for_peeler = recording # preestimate the sparsity unsing peaks channel + # spike_vector = sorting_pre_peeler.to_spike_vector(concatenated=True) + # sparsity, unit_locations = compute_sparsity_from_peaks_and_label( + # kept_peaks, + # spike_vector["unit_index"], + # sorting_pre_peeler.unit_ids, + # recording, + # params["templates"]["radius_um"], + # ) spike_vector = sorting_pre_peeler.to_spike_vector(concatenated=True) sparsity, unit_locations = compute_sparsity_from_peaks_and_label( - kept_peaks, - spike_vector["unit_index"], - sorting_pre_peeler.unit_ids, - recording, - params["templates"]["radius_um"], + kept_peaks, spike_vector["unit_index"], sorting_pre_peeler.unit_ids, recording, params["template_radius_um"] ) + # we recompute the template even if the clustering give it already because we use different ms_before/ms_after - nbefore = int(params["templates"]["ms_before"] * sampling_frequency / 1000.0) - nafter = int(params["templates"]["ms_after"] * sampling_frequency / 1000.0) + # nbefore = int(params["templates"]["ms_before"] * sampling_frequency / 1000.0) + # nafter = int(params["templates"]["ms_after"] * sampling_frequency / 1000.0) + ms_before = params["ms_before"] + ms_after = params["ms_after"] + nbefore = int(ms_before * sampling_frequency / 1000.0) + nafter = int(ms_after * sampling_frequency / 1000.0) + templates_array = estimate_templates_with_accumulator( recording_for_peeler, @@ -266,31 +373,42 
@@ def _run_from_folder(cls, sorter_output_folder, params, verbose): is_in_uV=False, ) - # this clean and spasify more + # this clean and sparsify more templates = clean_templates( templates, - sparsify_threshold=params["templates"]["sparsity_threshold"], + # sparsify_threshold=params["templates"]["sparsity_threshold"], + # noise_levels=noise_levels, + # min_snr=params["templates"]["min_snr"], + # max_jitter_ms=params["templates"]["max_jitter_ms"], + # remove_empty=True, + sparsify_threshold=params["template_sparsify_threshold"], noise_levels=noise_levels, - min_snr=params["templates"]["min_snr"], - max_jitter_ms=params["templates"]["max_jitter_ms"], + min_snr=params["template_min_snr_ptp"], + max_jitter_ms=params["template_max_jitter_ms"], remove_empty=True, ) ## peeler - matching_method = params["matching"].pop("method") - gather_mode = params["matching"].pop("gather_mode", "memory") - matching_params = params["matching"].get("matching_kwargs", {}).copy() - if matching_method in ("tdc-peeler",): - matching_params["noise_levels"] = noise_levels + # matching_method = params["matching"].pop("method") + # gather_mode = params["matching"].pop("gather_mode", "memory") + # matching_params = params["matching"].get("matching_kwargs", {}).copy() + # if matching_method in ("tdc-peeler",): + # matching_params["noise_levels"] = noise_levels + + # pipeline_kwargs = dict(gather_mode=gather_mode) + # if gather_mode == "npy": + # pipeline_kwargs["folder"] = sorter_output_folder / "matching" + gather_mode = params["gather_mode"] pipeline_kwargs = dict(gather_mode=gather_mode) if gather_mode == "npy": pipeline_kwargs["folder"] = sorter_output_folder / "matching" + method_kwargs = dict(noise_levels=noise_levels) spikes = find_spikes_from_templates( recording_for_peeler, templates, - method=matching_method, - method_kwargs=matching_params, + method="tdc-peeler", + method_kwargs=method_kwargs, pipeline_kwargs=pipeline_kwargs, job_kwargs=job_kwargs, ) From 
a4224c30e534ce66aedd23538371524bf85bd364 Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Thu, 11 Dec 2025 11:19:46 +0100 Subject: [PATCH 32/38] Better variable naming --- src/spikeinterface/sorters/internal/lupin.py | 4 ++-- .../sorters/internal/tridesclous2.py | 4 ++-- .../clustering/iterative_hdbscan.py | 19 ++++++++++++++++++- .../clustering/iterative_isosplit.py | 8 ++++---- 4 files changed, 26 insertions(+), 9 deletions(-) diff --git a/src/spikeinterface/sorters/internal/lupin.py b/src/spikeinterface/sorters/internal/lupin.py index 26fd952f6f..0b2bb6e74e 100644 --- a/src/spikeinterface/sorters/internal/lupin.py +++ b/src/spikeinterface/sorters/internal/lupin.py @@ -240,8 +240,8 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): clustering_kwargs["clean_templates"]["min_snr"] = params["template_min_snr_ptp"] clustering_kwargs["clean_templates"]["max_jitter_ms"] = params["template_max_jitter_ms"] clustering_kwargs["noise_levels"] = noise_levels - clustering_kwargs["clean"]["min_firing_rate"] = params["min_firing_rate"] - clustering_kwargs["clean"]["subsampling_factor"] = all_peaks.size / peaks.size + clustering_kwargs["clean_low_firing"]["min_firing_rate"] = params["min_firing_rate"] + clustering_kwargs["clean_low_firing"]["subsampling_factor"] = all_peaks.size / peaks.size if params["debug"]: clustering_kwargs["debug_folder"] = sorter_output_folder diff --git a/src/spikeinterface/sorters/internal/tridesclous2.py b/src/spikeinterface/sorters/internal/tridesclous2.py index 68b27c6b0b..0a7cb5f826 100644 --- a/src/spikeinterface/sorters/internal/tridesclous2.py +++ b/src/spikeinterface/sorters/internal/tridesclous2.py @@ -196,8 +196,8 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): if params["debug"]: clustering_kwargs["debug_folder"] = sorter_output_folder clustering_kwargs["noise_levels"] = noise_levels - clustering_kwargs["clean"]["min_firing_rate"] = params["min_firing_rate"] - 
clustering_kwargs["clean"]["subsampling_factor"] = all_peaks.size / peaks.size + clustering_kwargs["clean_low_firing"]["min_firing_rate"] = params["min_firing_rate"] + clustering_kwargs["clean_low_firing"]["subsampling_factor"] = all_peaks.size / peaks.size # if clustering_kwargs["clustering"]["clusterer"] == "isosplit6": # have_sisosplit6 = importlib.util.find_spec("isosplit6") is not None diff --git a/src/spikeinterface/sortingcomponents/clustering/iterative_hdbscan.py b/src/spikeinterface/sortingcomponents/clustering/iterative_hdbscan.py index d7f68de902..fffce80687 100644 --- a/src/spikeinterface/sortingcomponents/clustering/iterative_hdbscan.py +++ b/src/spikeinterface/sortingcomponents/clustering/iterative_hdbscan.py @@ -9,7 +9,7 @@ from spikeinterface.sortingcomponents.waveforms.peak_svd import extract_peaks_svd from spikeinterface.sortingcomponents.clustering.merging_tools import merge_peak_labels_from_templates from spikeinterface.sortingcomponents.clustering.itersplit_tools import split_clusters -from spikeinterface.sortingcomponents.clustering.tools import get_templates_from_peaks_and_svd +from spikeinterface.sortingcomponents.clustering.tools import get_templates_from_peaks_and_svd, remove_small_cluster from spikeinterface.sortingcomponents.tools import clean_templates from spikeinterface.core.recording_tools import get_noise_levels @@ -54,6 +54,10 @@ class IterativeHDBSCANClustering: }, "merge_from_templates": dict(similarity_thresh=0.8, num_shifts=3, use_lags=True), "merge_from_features": None, + "clean_low_firing": { + "min_firing_rate": 0.1, + "subsampling_factor": None, + }, "debug_folder": None, "verbose": True, } @@ -182,6 +186,19 @@ def main_function(cls, recording, peaks, params, job_kwargs=dict()): is_in_uV=False, ) + + # clean very small cluster before peeler + if params["clean_low_firing"]["subsampling_factor"] is not None and params["clean_low_firing"]["min_firing_rate"] is not None: + peak_labels, to_keep = remove_small_cluster( + 
recording, + peaks, + peak_labels, + min_firing_rate=params["clean_low_firing"]["min_firing_rate"], + subsampling_factor=params["clean_low_firing"]["subsampling_factor"], + verbose=verbose, + ) + templates = templates.select_units(to_keep) + labels = templates.unit_ids if debug_folder is not None: diff --git a/src/spikeinterface/sortingcomponents/clustering/iterative_isosplit.py b/src/spikeinterface/sortingcomponents/clustering/iterative_isosplit.py index c294564558..625d69725c 100644 --- a/src/spikeinterface/sortingcomponents/clustering/iterative_isosplit.py +++ b/src/spikeinterface/sortingcomponents/clustering/iterative_isosplit.py @@ -75,7 +75,7 @@ class IterativeISOSPLITClustering: }, "merge_from_features": None, # "merge_from_features": {"merge_radius_um": 60.0}, - "clean": { + "clean_low_firing": { "min_firing_rate": 0.1, "subsampling_factor": None, }, @@ -311,13 +311,13 @@ def main_function(cls, recording, peaks, params, job_kwargs=dict()): templates = dense_templates.to_sparse(sparsity) # clean very small cluster before peeler - if params["clean"]["subsampling_factor"] is not None and params["clean"]["min_firing_rate"] is not None: + if params["clean_low_firing"]["subsampling_factor"] is not None and params["clean_low_firing"]["min_firing_rate"] is not None: final_peak_labels, to_keep = remove_small_cluster( recording, peaks, post_merge_label2, - min_firing_rate=params["clean"]["min_firing_rate"], - subsampling_factor=params["clean"]["subsampling_factor"], + min_firing_rate=params["clean_low_firing"]["min_firing_rate"], + subsampling_factor=params["clean_low_firing"]["subsampling_factor"], verbose=verbose, ) templates = templates.select_units(to_keep) From 31572ec4748326e30608861e0bc949fb4a650bd5 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 11 Dec 2025 10:20:55 +0000 Subject: [PATCH 33/38] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see 
https://pre-commit.ci --- .../sortingcomponents/clustering/iterative_hdbscan.py | 6 ++++-- .../sortingcomponents/clustering/iterative_isosplit.py | 5 ++++- 2 files changed, 8 insertions(+), 3 deletions(-) diff --git a/src/spikeinterface/sortingcomponents/clustering/iterative_hdbscan.py b/src/spikeinterface/sortingcomponents/clustering/iterative_hdbscan.py index fffce80687..dc991d5b50 100644 --- a/src/spikeinterface/sortingcomponents/clustering/iterative_hdbscan.py +++ b/src/spikeinterface/sortingcomponents/clustering/iterative_hdbscan.py @@ -186,9 +186,11 @@ def main_function(cls, recording, peaks, params, job_kwargs=dict()): is_in_uV=False, ) - # clean very small cluster before peeler - if params["clean_low_firing"]["subsampling_factor"] is not None and params["clean_low_firing"]["min_firing_rate"] is not None: + if ( + params["clean_low_firing"]["subsampling_factor"] is not None + and params["clean_low_firing"]["min_firing_rate"] is not None + ): peak_labels, to_keep = remove_small_cluster( recording, peaks, diff --git a/src/spikeinterface/sortingcomponents/clustering/iterative_isosplit.py b/src/spikeinterface/sortingcomponents/clustering/iterative_isosplit.py index 625d69725c..b1d54df80c 100644 --- a/src/spikeinterface/sortingcomponents/clustering/iterative_isosplit.py +++ b/src/spikeinterface/sortingcomponents/clustering/iterative_isosplit.py @@ -311,7 +311,10 @@ def main_function(cls, recording, peaks, params, job_kwargs=dict()): templates = dense_templates.to_sparse(sparsity) # clean very small cluster before peeler - if params["clean_low_firing"]["subsampling_factor"] is not None and params["clean_low_firing"]["min_firing_rate"] is not None: + if ( + params["clean_low_firing"]["subsampling_factor"] is not None + and params["clean_low_firing"]["min_firing_rate"] is not None + ): final_peak_labels, to_keep = remove_small_cluster( recording, peaks, From c218faa864f6a76db97f7173bb60fab2f3494bde Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Thu, 11 Dec 2025 
11:21:26 +0100 Subject: [PATCH 34/38] oups --- src/spikeinterface/sorters/internal/tridesclous2.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/spikeinterface/sorters/internal/tridesclous2.py b/src/spikeinterface/sorters/internal/tridesclous2.py index f572bcd565..cccb69ca79 100644 --- a/src/spikeinterface/sorters/internal/tridesclous2.py +++ b/src/spikeinterface/sorters/internal/tridesclous2.py @@ -188,7 +188,8 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): rec_for_motion, folder=sorter_output_folder / "motion", output_motion_info=True, - **params["motion_correction"], + preset=params["motion_correction_preset"], + # **params["motion_correction"], ) if verbose: print("Done correct_motion()") From fefce5e915c2fcab4b0bf6a290f7e063919e3e76 Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Fri, 12 Dec 2025 16:51:06 +0100 Subject: [PATCH 35/38] oups --- .../sorters/internal/tridesclous2.py | 17 +++-------------- 1 file changed, 3 insertions(+), 14 deletions(-) diff --git a/src/spikeinterface/sorters/internal/tridesclous2.py b/src/spikeinterface/sorters/internal/tridesclous2.py index 08fd27cf97..2e3d1187a4 100644 --- a/src/spikeinterface/sorters/internal/tridesclous2.py +++ b/src/spikeinterface/sorters/internal/tridesclous2.py @@ -268,7 +268,7 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): if verbose: print(f"select_peaks(): {len(peaks)} peaks kept for clustering") - # routing clustering params into the big IterativeISOSPLITClustering params tree + # routing clustering params into the big IterativeISOSPLITClustering params tree # clustering_kwargs = deepcopy(clustering_methods["iterative-isosplit"]._default_params) # clustering_kwargs["peaks_svd"].update(params["waveforms"]) # clustering_kwargs["peaks_svd"].update(params["svd"]) @@ -279,17 +279,6 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): # clustering_kwargs["clean"]["min_firing_rate"] = params["min_firing_rate"] # 
clustering_kwargs["clean"]["subsampling_factor"] = all_peaks.size / peaks.size - clustering_kwargs = deepcopy(clustering_methods["iterative-isosplit"]._default_params) - clustering_kwargs["peaks_svd"].update(params["waveforms"]) - clustering_kwargs["peaks_svd"].update(params["svd"]) - clustering_kwargs["split"].update(params["clustering"]) - if params["debug"]: - clustering_kwargs["debug_folder"] = sorter_output_folder - clustering_kwargs["noise_levels"] = noise_levels - clustering_kwargs["clean_low_firing"]["min_firing_rate"] = params["min_firing_rate"] - clustering_kwargs["clean_low_firing"]["subsampling_factor"] = all_peaks.size / peaks.size - - # if clustering_kwargs["clustering"]["clusterer"] == "isosplit6": # have_sisosplit6 = importlib.util.find_spec("isosplit6") is not None # if not have_sisosplit6: @@ -309,8 +298,8 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): clustering_kwargs["clean_templates"]["min_snr"] = params["template_min_snr_ptp"] clustering_kwargs["clean_templates"]["max_jitter_ms"] = params["template_max_jitter_ms"] clustering_kwargs["noise_levels"] = noise_levels - clustering_kwargs["clean"]["min_firing_rate"] = params["min_firing_rate"] - clustering_kwargs["clean"]["subsampling_factor"] = all_peaks.size / peaks.size + clustering_kwargs["clean_low_firing"]["min_firing_rate"] = params["min_firing_rate"] + clustering_kwargs["clean_low_firing"]["subsampling_factor"] = all_peaks.size / peaks.size if params["debug"]: clustering_kwargs["debug_folder"] = sorter_output_folder From 76d5b15bf803ae69cb1393d725a9f147e1bc7712 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 12 Dec 2025 15:55:51 +0000 Subject: [PATCH 36/38] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/spikeinterface/sorters/internal/tridesclous2.py | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git 
a/src/spikeinterface/sorters/internal/tridesclous2.py b/src/spikeinterface/sorters/internal/tridesclous2.py index 6ea8f83d9a..c71c459983 100644 --- a/src/spikeinterface/sorters/internal/tridesclous2.py +++ b/src/spikeinterface/sorters/internal/tridesclous2.py @@ -82,7 +82,6 @@ class Tridesclous2Sorter(ComponentsBasedSorter): # "save_array": "Save or not intermediate arrays", # } - _default_params = { "apply_preprocessing": True, "apply_motion_correction": False, @@ -96,7 +95,7 @@ class Tridesclous2Sorter(ComponentsBasedSorter): "freq_max": 6000.0, "cache_preprocessing_mode": "auto", "peak_sign": "neg", - "detect_threshold": 5., + "detect_threshold": 5.0, "n_peaks_per_channel": 5000, "n_svd_components_per_channel": 5, "n_pca_features": 6, @@ -270,7 +269,7 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): # "You want to run tridesclous2 with the isosplit6 (the C++) implementation, but this is not installed, please `pip install isosplit6`" # ) - # Clustering + # Clustering clustering_kwargs = deepcopy(clustering_methods["iterative-isosplit"]._default_params) clustering_kwargs["peaks_svd"]["ms_before"] = params["clustering_ms_before"] clustering_kwargs["peaks_svd"]["ms_after"] = params["clustering_ms_after"] @@ -318,14 +317,12 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): kept_peaks, spike_vector["unit_index"], sorting_pre_peeler.unit_ids, recording, params["template_radius_um"] ) - # we recompute the template even if the clustering give it already because we use different ms_before/ms_after ms_before = params["ms_before"] ms_after = params["ms_after"] nbefore = int(ms_before * sampling_frequency / 1000.0) nafter = int(ms_after * sampling_frequency / 1000.0) - templates_array = estimate_templates_with_accumulator( recording_for_peeler, sorting_pre_peeler.to_spike_vector(), From 80dfb04d8915ff3f83d8689a0afc75e4de994ef0 Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Fri, 12 Dec 2025 16:57:13 +0100 Subject: [PATCH 37/38] 
update lupin and sc2 version --- src/spikeinterface/sorters/internal/lupin.py | 2 +- src/spikeinterface/sorters/internal/spyking_circus2.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/spikeinterface/sorters/internal/lupin.py b/src/spikeinterface/sorters/internal/lupin.py index 0b2bb6e74e..40bf32e005 100644 --- a/src/spikeinterface/sorters/internal/lupin.py +++ b/src/spikeinterface/sorters/internal/lupin.py @@ -97,7 +97,7 @@ class LupinSorter(ComponentsBasedSorter): @classmethod def get_sorter_version(cls): - return "2025.11" + return "2025.12" @classmethod def _run_from_folder(cls, sorter_output_folder, params, verbose): diff --git a/src/spikeinterface/sorters/internal/spyking_circus2.py b/src/spikeinterface/sorters/internal/spyking_circus2.py index 49418a1706..01e480ea51 100644 --- a/src/spikeinterface/sorters/internal/spyking_circus2.py +++ b/src/spikeinterface/sorters/internal/spyking_circus2.py @@ -89,7 +89,7 @@ class Spykingcircus2Sorter(ComponentsBasedSorter): @classmethod def get_sorter_version(cls): - return "2025.10" + return "2025.12" @classmethod def _run_from_folder(cls, sorter_output_folder, params, verbose): From d5520a63fc0729c36b77f7343d2cb750b574e95d Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Fri, 12 Dec 2025 17:00:20 +0100 Subject: [PATCH 38/38] yep --- .../sorters/internal/tridesclous2.py | 53 ------------------- 1 file changed, 53 deletions(-) diff --git a/src/spikeinterface/sorters/internal/tridesclous2.py b/src/spikeinterface/sorters/internal/tridesclous2.py index c71c459983..5b4689cba1 100644 --- a/src/spikeinterface/sorters/internal/tridesclous2.py +++ b/src/spikeinterface/sorters/internal/tridesclous2.py @@ -29,59 +29,6 @@ class Tridesclous2Sorter(ComponentsBasedSorter): sorter_name = "tridesclous2" - # _default_params = { - # "apply_preprocessing": True, - # "apply_motion_correction": False, - # "motion_correction": {"preset": "dredge_fast"}, - # "cache_preprocessing_mode": "auto", - # "waveforms": { - 
# "ms_before": 0.5, - # "ms_after": 1.5, - # "radius_um": 120.0, - # }, - # "filtering": { - # "freq_min": 150.0, - # "freq_max": 6000.0, - # "ftype": "bessel", - # "filter_order": 2, - # }, - # "detection": {"peak_sign": "neg", "detect_threshold": 5, "exclude_sweep_ms": 1.5, "radius_um": 150.0}, - # "selection": {"n_peaks_per_channel": 5000, "min_n_peaks": 20000}, - # "svd": {"n_components": 5}, - # "clustering": { - # "recursive_depth": 3, - # }, - # "min_firing_rate": 0.1, - # "templates": { - # "ms_before": 2.0, - # "ms_after": 3.0, - # "max_spikes_per_unit": 400, - # "sparsity_threshold": 1.5, - # "min_snr": 3.5, - # "radius_um": 100.0, - # "max_jitter_ms": 0.2, - # }, - # "matching": {"method": "tdc-peeler", "method_kwargs": {}, "gather_mode": "memory"}, - # "job_kwargs": {}, - # "save_array": True, - # "debug": False, - # } - - # _params_description = { - # "apply_preprocessing": "Apply internal preprocessing or not", - # "cache_preprocessing": "A dict contaning how to cache the preprocessed recording. mode='memory' | 'folder | 'zarr' ", - # "waveforms": "A dictonary containing waveforms params: ms_before, ms_after, radius_um", - # "filtering": "A dictonary containing filtering params: freq_min, freq_max", - # "detection": "A dictonary containing detection params: peak_sign, detect_threshold, exclude_sweep_ms, radius_um", - # "selection": "A dictonary containing selection params: n_peaks_per_channel, min_n_peaks", - # "svd": "A dictonary containing svd params: n_components", - # "clustering": "A dictonary containing clustering params: split_radius_um, merge_radius_um", - # "templates": "A dictonary containing waveforms params for peeler: ms_before, ms_after", - # "matching": "A dictonary containing matching params for matching: peak_shift_ms, radius_um", - # "job_kwargs": "A dictionary containing job kwargs", - # "save_array": "Save or not intermediate arrays", - # } - _default_params = { "apply_preprocessing": True, "apply_motion_correction": False,