diff --git a/src/spikeinterface/benchmark/benchmark_clustering.py b/src/spikeinterface/benchmark/benchmark_clustering.py
index 3885fe073c..6cf6cbe7a3 100644
--- a/src/spikeinterface/benchmark/benchmark_clustering.py
+++ b/src/spikeinterface/benchmark/benchmark_clustering.py
@@ -29,9 +29,14 @@ def __init__(self, recording, gt_sorting, params, indices, peaks, exhaustive_gt=
         self.method_kwargs = params["method_kwargs"]
         self.result = {}

-    def run(self, **job_kwargs):
+    def run(self, verbose=True, **job_kwargs):
         labels, peak_labels = find_clusters_from_peaks(
-            self.recording, self.peaks, method=self.method, method_kwargs=self.method_kwargs, job_kwargs=job_kwargs
+            self.recording,
+            self.peaks,
+            method=self.method,
+            method_kwargs=self.method_kwargs,
+            verbose=verbose,
+            job_kwargs=job_kwargs,
         )

         self.result["peak_labels"] = peak_labels
diff --git a/src/spikeinterface/benchmark/benchmark_matching.py b/src/spikeinterface/benchmark/benchmark_matching.py
index 8115870876..02e2146b5a 100644
--- a/src/spikeinterface/benchmark/benchmark_matching.py
+++ b/src/spikeinterface/benchmark/benchmark_matching.py
@@ -26,9 +26,14 @@ def __init__(self, recording, gt_sorting, params):
         self.method_kwargs = params["method_kwargs"]
         self.result = {}

-    def run(self, **job_kwargs):
+    def run(self, verbose=True, **job_kwargs):
         spikes = find_spikes_from_templates(
-            self.recording, self.templates, method=self.method, method_kwargs=self.method_kwargs, job_kwargs=job_kwargs
+            self.recording,
+            self.templates,
+            method=self.method,
+            method_kwargs=self.method_kwargs,
+            verbose=verbose,
+            job_kwargs=job_kwargs,
         )
         unit_ids = self.templates.unit_ids
         sorting = np.zeros(spikes.size, dtype=minimum_spike_dtype)
clustering_kwargs["split"]["method_kwargs"]["n_pca_features"] = params["n_pca_features"] + clustering_kwargs["clean_templates"]["sparsify_threshold"] = params["template_sparsify_threshold"] + clustering_kwargs["clean_templates"]["min_snr"] = params["template_min_snr_ptp"] + clustering_kwargs["clean_templates"]["max_jitter_ms"] = params["template_max_jitter_ms"] + clustering_kwargs["noise_levels"] = noise_levels + clustering_kwargs["clean_low_firing"]["min_firing_rate"] = params["min_firing_rate"] + clustering_kwargs["clean_low_firing"]["subsampling_factor"] = all_peaks.size / peaks.size if params["debug"]: clustering_kwargs["debug_folder"] = sorter_output_folder @@ -290,10 +300,10 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): # this spasify more templates = clean_templates( templates, - sparsify_threshold=params["sparsity_threshold"], + sparsify_threshold=params["template_sparsify_threshold"], noise_levels=noise_levels, - min_snr=params["template_min_snr"], - max_jitter_ms=None, + min_snr=params["template_min_snr_ptp"], + max_jitter_ms=params["template_max_jitter_ms"], remove_empty=True, ) diff --git a/src/spikeinterface/sorters/internal/spyking_circus2.py b/src/spikeinterface/sorters/internal/spyking_circus2.py index d4f9f39cbd..49418a1706 100644 --- a/src/spikeinterface/sorters/internal/spyking_circus2.py +++ b/src/spikeinterface/sorters/internal/spyking_circus2.py @@ -39,7 +39,8 @@ class Spykingcircus2Sorter(ComponentsBasedSorter): "motion_correction": {"preset": "dredge_fast"}, "merging": {"max_distance_um": 50}, "clustering": {"method": "iterative-hdbscan", "method_kwargs": dict()}, - "cleaning": {"min_snr": 5, "max_jitter_ms": 0.1, "sparsify_threshold": None}, + "cleaning": {"min_snr": 5, "max_jitter_ms": 0.2, "sparsify_threshold": 1, "mean_sd_ratio_threshold": 3}, + "min_firing_rate": 0.1, "matching": {"method": "circus-omp", "method_kwargs": dict(), "pipeline_kwargs": dict()}, "apply_preprocessing": True, "apply_whitening": True, @@ -103,6 +104,7 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): from spikeinterface.sortingcomponents.peak_detection import detect_peaks from spikeinterface.sortingcomponents.peak_selection import select_peaks from spikeinterface.sortingcomponents.clustering import find_clusters_from_peaks + from spikeinterface.sortingcomponents.clustering.tools import remove_small_cluster from spikeinterface.sortingcomponents.matching import find_spikes_from_templates from spikeinterface.sortingcomponents.tools import check_probe_for_drift_correction from spikeinterface.sortingcomponents.tools import clean_templates @@ -118,8 +120,6 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): ms_before = params["general"].get("ms_before", 0.5) ms_after = params["general"].get("ms_after", 1.5) radius_um = params["general"].get("radius_um", 100.0) - detect_threshold = params["detection"]["method_kwargs"].get("detect_threshold", 5) - peak_sign = params["detection"].get("peak_sign", "neg") deterministic = params["deterministic_peaks_detection"] debug = params["debug"] seed = params["seed"] @@ -310,6 +310,9 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): if verbose: print("Kept %d peaks for clustering" % len(selected_peaks)) + cleaning_kwargs = params.get("cleaning", {}).copy() + cleaning_kwargs["remove_empty"] = True + if clustering_method in [ "iterative-hdbscan", "iterative-isosplit", @@ -319,6 +322,10 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): 
diff --git a/src/spikeinterface/sorters/internal/spyking_circus2.py b/src/spikeinterface/sorters/internal/spyking_circus2.py
index d4f9f39cbd..49418a1706 100644
--- a/src/spikeinterface/sorters/internal/spyking_circus2.py
+++ b/src/spikeinterface/sorters/internal/spyking_circus2.py
@@ -39,7 +39,8 @@ class Spykingcircus2Sorter(ComponentsBasedSorter):
         "motion_correction": {"preset": "dredge_fast"},
         "merging": {"max_distance_um": 50},
         "clustering": {"method": "iterative-hdbscan", "method_kwargs": dict()},
-        "cleaning": {"min_snr": 5, "max_jitter_ms": 0.1, "sparsify_threshold": None},
+        "cleaning": {"min_snr": 5, "max_jitter_ms": 0.2, "sparsify_threshold": 1, "mean_sd_ratio_threshold": 3},
+        "min_firing_rate": 0.1,
         "matching": {"method": "circus-omp", "method_kwargs": dict(), "pipeline_kwargs": dict()},
         "apply_preprocessing": True,
         "apply_whitening": True,
@@ -103,6 +104,7 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose):
         from spikeinterface.sortingcomponents.peak_detection import detect_peaks
         from spikeinterface.sortingcomponents.peak_selection import select_peaks
         from spikeinterface.sortingcomponents.clustering import find_clusters_from_peaks
+        from spikeinterface.sortingcomponents.clustering.tools import remove_small_cluster
         from spikeinterface.sortingcomponents.matching import find_spikes_from_templates
         from spikeinterface.sortingcomponents.tools import check_probe_for_drift_correction
         from spikeinterface.sortingcomponents.tools import clean_templates
@@ -118,8 +120,6 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose):
         ms_before = params["general"].get("ms_before", 0.5)
         ms_after = params["general"].get("ms_after", 1.5)
         radius_um = params["general"].get("radius_um", 100.0)
-        detect_threshold = params["detection"]["method_kwargs"].get("detect_threshold", 5)
-        peak_sign = params["detection"].get("peak_sign", "neg")
         deterministic = params["deterministic_peaks_detection"]
         debug = params["debug"]
         seed = params["seed"]
@@ -310,6 +310,9 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose):
             if verbose:
                 print("Kept %d peaks for clustering" % len(selected_peaks))

+        cleaning_kwargs = params.get("cleaning", {}).copy()
+        cleaning_kwargs["remove_empty"] = True
+
         if clustering_method in [
             "iterative-hdbscan",
             "iterative-isosplit",
@@ -319,6 +322,10 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose):
             clustering_params.update(verbose=verbose)
             clustering_params.update(seed=seed)
             clustering_params.update(peaks_svd=params["general"])
+            if clustering_method in ["iterative-hdbscan", "iterative-isosplit"]:
+                clustering_params.update(clean_templates=cleaning_kwargs)
+                clustering_params["noise_levels"] = noise_levels
+
             if debug:
                 clustering_params["debug_folder"] = sorter_output_folder / "clustering"
@@ -328,6 +335,7 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose):
             method=clustering_method,
             method_kwargs=clustering_params,
             extra_outputs=True,
+            verbose=verbose,
             job_kwargs=job_kwargs,
         )
@@ -365,7 +373,7 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose):
         else:
             from spikeinterface.sortingcomponents.clustering.tools import get_templates_from_peaks_and_svd

-            dense_templates, new_sparse_mask = get_templates_from_peaks_and_svd(
+            dense_templates, new_sparse_mask, max_std_per_channel = get_templates_from_peaks_and_svd(
                 recording_w,
                 selected_peaks,
                 peak_labels,
@@ -375,16 +383,30 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose):
                 more_outs["peaks_svd"],
                 more_outs["peak_svd_sparse_mask"],
                 operator="median",
+                return_max_std_per_channel=True,
             )
             # this release the peak_svd memmap file
             templates = dense_templates.to_sparse(new_sparse_mask)
             del more_outs

-            cleaning_kwargs = params.get("cleaning", {}).copy()
-            cleaning_kwargs["noise_levels"] = noise_levels
-            cleaning_kwargs["remove_empty"] = True
-            templates = clean_templates(templates, **cleaning_kwargs)
+            before_clean_ids = templates.unit_ids.copy()
+            cleaning_kwargs["max_std_per_channel"] = max_std_per_channel
+            cleaning_kwargs["verbose"] = verbose
+            templates = clean_templates(templates, noise_levels=noise_levels, **cleaning_kwargs)
+            remove_peak_mask = ~np.isin(peak_labels, templates.unit_ids)
+            peak_labels[remove_peak_mask] = -1
+
+        if params["min_firing_rate"] is not None:
+            peak_labels, to_keep = remove_small_cluster(
+                recording_w,
+                selected_peaks,
+                peak_labels,
+                min_firing_rate=params["min_firing_rate"],
+                subsampling_factor=peaks.size / selected_peaks.size,
+                verbose=verbose,
+            )
+            templates = templates.select_units(to_keep)

         if verbose:
             print("Kept %d clean clusters" % len(templates.unit_ids))
@@ -508,5 +530,4 @@ def final_cleaning_circus(
         sparsity_overlap=sparsity_overlap,
         **job_kwargs,
     )
-
     return final_sa
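Note (not part of the patch): an illustrative sketch of how the new spyking-circus2 cleaning controls introduced above would be passed from user code, assuming a preprocessed `recording` is available. The parameter values simply mirror the new defaults; `run_sorter` forwards keyword arguments into the sorter's params dict.

    # Hypothetical usage sketch, values mirror the new defaults above.
    from spikeinterface.sorters import run_sorter

    sorting = run_sorter(
        "spykingcircus2",
        recording,  # assumed: any preprocessed Recording
        folder="sc2_output",
        # new cleaning dict: jitter threshold raised to 0.2 ms, sparsify_threshold
        # now enabled by default, plus the new mean sd-ratio criterion
        cleaning={"min_snr": 5, "max_jitter_ms": 0.2, "sparsify_threshold": 1, "mean_sd_ratio_threshold": 3},
        # new top-level parameter: clusters firing below ~0.1 Hz
        # (subsampling-corrected) are dropped before template matching
        min_firing_rate=0.1,
    )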
diff --git a/src/spikeinterface/sorters/internal/tridesclous2.py b/src/spikeinterface/sorters/internal/tridesclous2.py
index 2f7e84dcf4..0a7cb5f826 100644
--- a/src/spikeinterface/sorters/internal/tridesclous2.py
+++ b/src/spikeinterface/sorters/internal/tridesclous2.py
@@ -51,13 +51,15 @@ class Tridesclous2Sorter(ComponentsBasedSorter):
         "clustering": {
             "recursive_depth": 3,
         },
+        "min_firing_rate": 0.1,
         "templates": {
             "ms_before": 2.0,
             "ms_after": 3.0,
             "max_spikes_per_unit": 400,
             "sparsity_threshold": 1.5,
-            "min_snr": 2.5,
+            "min_snr": 3.5,
             "radius_um": 100.0,
+            "max_jitter_ms": 0.2,
         },
         "matching": {"method": "tdc-peeler", "method_kwargs": {}, "gather_mode": "memory"},
         "job_kwargs": {},
@@ -93,7 +95,6 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose):
         from spikeinterface.sortingcomponents.peak_detection import detect_peaks
         from spikeinterface.sortingcomponents.peak_selection import select_peaks
         from spikeinterface.sortingcomponents.clustering.main import find_clusters_from_peaks, clustering_methods
-        from spikeinterface.sortingcomponents.tools import remove_empty_templates
         from spikeinterface.preprocessing import correct_motion
         from spikeinterface.sortingcomponents.motion import InterpolateMotionRecording
         from spikeinterface.sortingcomponents.tools import clean_templates, compute_sparsity_from_peaks_and_label
@@ -194,6 +195,9 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose):
         clustering_kwargs["split"].update(params["clustering"])
         if params["debug"]:
             clustering_kwargs["debug_folder"] = sorter_output_folder
+        clustering_kwargs["noise_levels"] = noise_levels
+        clustering_kwargs["clean_low_firing"]["min_firing_rate"] = params["min_firing_rate"]
+        clustering_kwargs["clean_low_firing"]["subsampling_factor"] = all_peaks.size / peaks.size

         # if clustering_kwargs["clustering"]["clusterer"] == "isosplit6":
         #     have_sisosplit6 = importlib.util.find_spec("isosplit6") is not None
@@ -262,13 +266,13 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose):
             is_in_uV=False,
         )

-        # this spasify more
+        # this cleans and sparsifies more
         templates = clean_templates(
             templates,
             sparsify_threshold=params["templates"]["sparsity_threshold"],
             noise_levels=noise_levels,
             min_snr=params["templates"]["min_snr"],
-            max_jitter_ms=None,
+            max_jitter_ms=params["templates"]["max_jitter_ms"],
             remove_empty=True,
         )
cleaning_kwargs["max_std_per_channel"] = max_std_per_channel + if params["noise_levels"] is not None: + noise_levels = params["noise_levels"] + else: + noise_levels = get_noise_levels(recording, return_in_uV=False, **job_kwargs) + cleaning_kwargs["noise_levels"] = noise_levels + cleaned_templates = clean_templates(templates, **cleaning_kwargs) + mask_keep_ids = np.isin(templates.unit_ids, cleaned_templates.unit_ids) + to_remove_ids = templates.unit_ids[~mask_keep_ids] + to_remove_label_mask = np.isin(peak_labels, to_remove_ids) + peak_labels[to_remove_label_mask] = -1 + templates = cleaned_templates + new_sparse_mask = templates.sparsity.mask.copy() + templates = templates.to_dense() labels = templates.unit_ids if verbose: @@ -154,6 +186,21 @@ def main_function(cls, recording, peaks, params, job_kwargs=dict()): is_in_uV=False, ) + # clean very small cluster before peeler + if ( + params["clean_low_firing"]["subsampling_factor"] is not None + and params["clean_low_firing"]["min_firing_rate"] is not None + ): + peak_labels, to_keep = remove_small_cluster( + recording, + peaks, + peak_labels, + min_firing_rate=params["clean_low_firing"]["min_firing_rate"], + subsampling_factor=params["clean_low_firing"]["subsampling_factor"], + verbose=verbose, + ) + templates = templates.select_units(to_keep) + labels = templates.unit_ids if debug_folder is not None: diff --git a/src/spikeinterface/sortingcomponents/clustering/iterative_isosplit.py b/src/spikeinterface/sortingcomponents/clustering/iterative_isosplit.py index 6e08548329..b1d54df80c 100644 --- a/src/spikeinterface/sortingcomponents/clustering/iterative_isosplit.py +++ b/src/spikeinterface/sortingcomponents/clustering/iterative_isosplit.py @@ -10,8 +10,10 @@ merge_peak_labels_from_templates, merge_peak_labels_from_features, ) -from spikeinterface.sortingcomponents.clustering.tools import get_templates_from_peaks_and_svd +from spikeinterface.sortingcomponents.clustering.tools import get_templates_from_peaks_and_svd, remove_small_cluster +from spikeinterface.sortingcomponents.tools import clean_templates from spikeinterface.sortingcomponents.waveforms.peak_svd import extract_peaks_svd +from spikeinterface.core.recording_tools import get_noise_levels class IterativeISOSPLITClustering: @@ -32,6 +34,7 @@ class IterativeISOSPLITClustering: _default_params = { "motion": None, "seed": None, + "noise_levels": None, "peaks_svd": {"n_components": 5, "ms_before": 0.5, "ms_after": 1.5, "radius_um": 120.0, "motion": None}, "pre_label": { "mode": "channel", @@ -59,6 +62,12 @@ class IterativeISOSPLITClustering: # "projection_mode": "pca", }, }, + "clean_templates": { + "max_jitter_ms": 0.2, + "min_snr": 2.5, + "sparsify_threshold": 1.0, + "remove_empty": True, + }, "merge_from_templates": { "similarity_metric": "l1", "num_shifts": 3, @@ -66,8 +75,9 @@ class IterativeISOSPLITClustering: }, "merge_from_features": None, # "merge_from_features": {"merge_radius_um": 60.0}, - "clean": { - "minimum_cluster_size": 10, + "clean_low_firing": { + "min_firing_rate": 0.1, + "subsampling_factor": None, }, "debug_folder": None, "verbose": True, @@ -97,6 +107,7 @@ def main_function(cls, recording, peaks, params, job_kwargs=dict()): ms_before = params["peaks_svd"]["ms_before"] ms_after = params["peaks_svd"]["ms_after"] # radius_um = params["waveforms"]["radius_um"] + verbose = params["verbose"] debug_folder = params["debug_folder"] @@ -206,8 +217,49 @@ def main_function(cls, recording, peaks, params, job_kwargs=dict()): operator="average", ) - unit_ids = 
diff --git a/src/spikeinterface/sortingcomponents/clustering/iterative_isosplit.py b/src/spikeinterface/sortingcomponents/clustering/iterative_isosplit.py
index 6e08548329..b1d54df80c 100644
--- a/src/spikeinterface/sortingcomponents/clustering/iterative_isosplit.py
+++ b/src/spikeinterface/sortingcomponents/clustering/iterative_isosplit.py
@@ -10,8 +10,10 @@
     merge_peak_labels_from_templates,
     merge_peak_labels_from_features,
 )
-from spikeinterface.sortingcomponents.clustering.tools import get_templates_from_peaks_and_svd
+from spikeinterface.sortingcomponents.clustering.tools import get_templates_from_peaks_and_svd, remove_small_cluster
+from spikeinterface.sortingcomponents.tools import clean_templates
 from spikeinterface.sortingcomponents.waveforms.peak_svd import extract_peaks_svd
+from spikeinterface.core.recording_tools import get_noise_levels


 class IterativeISOSPLITClustering:
@@ -32,6 +34,7 @@ class IterativeISOSPLITClustering:
     _default_params = {
         "motion": None,
         "seed": None,
+        "noise_levels": None,
         "peaks_svd": {"n_components": 5, "ms_before": 0.5, "ms_after": 1.5, "radius_um": 120.0, "motion": None},
         "pre_label": {
             "mode": "channel",
@@ -59,6 +62,12 @@ class IterativeISOSPLITClustering:
             #     "projection_mode": "pca",
             },
         },
+        "clean_templates": {
+            "max_jitter_ms": 0.2,
+            "min_snr": 2.5,
+            "sparsify_threshold": 1.0,
+            "remove_empty": True,
+        },
         "merge_from_templates": {
             "similarity_metric": "l1",
             "num_shifts": 3,
@@ -66,8 +75,9 @@ class IterativeISOSPLITClustering:
         },
         "merge_from_features": None,
         # "merge_from_features": {"merge_radius_um": 60.0},
-        "clean": {
-            "minimum_cluster_size": 10,
+        "clean_low_firing": {
+            "min_firing_rate": 0.1,
+            "subsampling_factor": None,
         },
         "debug_folder": None,
         "verbose": True,
@@ -97,6 +107,7 @@ def main_function(cls, recording, peaks, params, job_kwargs=dict()):
         ms_before = params["peaks_svd"]["ms_before"]
         ms_after = params["peaks_svd"]["ms_after"]
         # radius_um = params["waveforms"]["radius_um"]
+        verbose = params["verbose"]

         debug_folder = params["debug_folder"]
@@ -206,8 +217,49 @@ def main_function(cls, recording, peaks, params, job_kwargs=dict()):
             operator="average",
         )

-        unit_ids = dense_templates.unit_ids
+        ## Pre-clean using templates (jitter, sparsify_threshold)
+        templates = dense_templates.to_sparse(template_sparse_mask)
+        cleaning_kwargs = params["clean_templates"].copy()
+        cleaning_kwargs["verbose"] = verbose
+        # cleaning_kwargs["verbose"] = True
+        # cleaning_kwargs["max_std_per_channel"] = max_std_per_channel
+        if params["noise_levels"] is not None:
+            noise_levels = params["noise_levels"]
+        else:
+            noise_levels = get_noise_levels(recording, return_in_uV=False, **job_kwargs)
+        cleaning_kwargs["noise_levels"] = noise_levels
+        cleaned_templates = clean_templates(templates, **cleaning_kwargs)
+        mask_keep_ids = np.isin(templates.unit_ids, cleaned_templates.unit_ids)
+        to_remove_ids = templates.unit_ids[~mask_keep_ids]
+        to_remove_label_mask = np.isin(post_split_label, to_remove_ids)
+        post_split_label[to_remove_label_mask] = -1
+        template_sparse_mask = cleaned_templates.sparsity.mask.copy()
+        dense_templates = cleaned_templates.to_dense()
         templates_array = dense_templates.templates_array
+        unit_ids = dense_templates.unit_ids
+
+        # ## Pre clean using templates (jitter)
+        # cleaned_templates = clean_templates(
+        #     dense_templates,
+        #     # sparsify_threshold=0.25,
+        #     sparsify_threshold=None,
+        #     # noise_levels=None,
+        #     # min_snr=None,
+        #     max_jitter_ms=params["clean_templates"]["max_jitter_ms"],
+        #     # remove_empty=True,
+        #     remove_empty=False,
+        #     # sd_ratio_threshold=5.0,
+        #     # stds_at_peak=None,
+        # )
+        # mask_keep_ids = np.isin(dense_templates.unit_ids, cleaned_templates.unit_ids)
+        # to_remove_ids = dense_templates.unit_ids[~mask_keep_ids]
+        # to_remove_label_mask = np.isin(post_split_label, to_remove_ids)
+        # post_split_label[to_remove_label_mask] = -1
+        # dense_templates = cleaned_templates
+        # template_sparse_mask = template_sparse_mask[mask_keep_ids, :]
+
+        # unit_ids = dense_templates.unit_ids
+        # templates_array = dense_templates.templates_array

         if params["merge_from_features"] is not None:
@@ -258,197 +310,26 @@ def main_function(cls, recording, peaks, params, job_kwargs=dict()):
         sparsity = ChannelSparsity(template_sparse_mask, unit_ids, recording.channel_ids)
         templates = dense_templates.to_sparse(sparsity)

-        # sparse_wfs = np.load(features_folder / "sparse_wfs.npy", mmap_mode="r")
-
-        # new_peaks = peaks.copy()
-        # new_peaks["sample_index"] -= peak_shifts
-
-        # clean very small cluster before peeler
-        post_clean_label = post_merge_label2.copy()
-        minimum_cluster_size = params["clean"]["minimum_cluster_size"]
-        labels_set, count = np.unique(post_clean_label, return_counts=True)
-        to_remove = labels_set[count < minimum_cluster_size]
-        mask = np.isin(post_clean_label, to_remove)
-        post_clean_label[mask] = -1
-        final_peak_labels = post_clean_label
-        labels_set = np.unique(final_peak_labels)
-        labels_set = labels_set[labels_set >= 0]
-        templates = templates.select_units(labels_set)
+        if (
+            params["clean_low_firing"]["subsampling_factor"] is not None
+            and params["clean_low_firing"]["min_firing_rate"] is not None
+        ):
+            final_peak_labels, to_keep = remove_small_cluster(
+                recording,
+                peaks,
+                post_merge_label2,
+                min_firing_rate=params["clean_low_firing"]["min_firing_rate"],
+                subsampling_factor=params["clean_low_firing"]["subsampling_factor"],
+                verbose=verbose,
+            )
+            templates = templates.select_units(to_keep)
+        else:
+            final_peak_labels = post_merge_label2
+
         labels_set = templates.unit_ids

         more_outs = dict(
             templates=templates,
         )
         return labels_set, final_peak_labels, more_outs
-
-    # _default_params = {
-    #     "clean": {
-    #         "minimum_cluster_size": 10,
-    # }
-
-    # @classmethod
-    # def main_function(cls, recording, peaks, params, job_kwargs=dict()):
-
-    #     split_radius_um = params["split"].pop("split_radius_um", 40)
-    #     peaks_svd = params["peaks_svd"]
-    #     motion = peaks_svd["motion"]
-    #     ms_before = peaks_svd.get("ms_before", 0.5)
-    #     ms_after = peaks_svd.get("ms_after", 1.5)
-    #     verbose = params.get("verbose", True)
-    #     split = params["split"]
-    #     seed = params["seed"]
-    #     job_kwargs = params.get("job_kwargs", dict())
-    #     debug_folder = params.get("debug_folder", None)
-
-    #     if debug_folder is not None:
-    #         debug_folder = Path(debug_folder).absolute()
-    #         debug_folder.mkdir(exist_ok=True)
-    #         peaks_svd.update(folder=debug_folder / "features")
-
-    #     motion_aware = motion is not None
-    #     peaks_svd.update(motion_aware=motion_aware)
-
-    #     if seed is not None:
-    #         peaks_svd.update(seed=seed)
-    #         split["method_kwargs"].update(seed=seed)
-
-    #     outs = extract_peaks_svd(
-    #         recording,
-    #         peaks,
-    #         **peaks_svd,
-    #         **job_kwargs,
-    #     )
-
-    #     if motion_aware:
-    #         # also return peaks with new channel index
-    #         peaks_svd, sparse_mask, svd_model, moved_peaks = outs
-    #         peaks = moved_peaks
-    #     else:
-    #         peaks_svd, sparse_mask, svd_model = outs
-
-    #     if debug_folder is not None:
-    #         np.save(debug_folder / "sparse_mask.npy", sparse_mask)
-    #         np.save(debug_folder / "peaks.npy", peaks)
-
-    #     split["method_kwargs"].update(waveforms_sparse_mask = sparse_mask)
-    #     neighbours_mask = get_channel_distances(recording) <= split_radius_um
-    #     split["method_kwargs"].update(neighbours_mask = neighbours_mask)
-
-    #     if debug_folder is not None:
-    #         split.update(debug_folder = debug_folder / "split")
-
-    #     peak_labels = split_clusters(
-    #         peaks["channel_index"],
-    #         recording,
-    #         {"peaks": peaks, "sparse_tsvd": peaks_svd},
-    #         method="local_feature_clustering",
-    #         **split,
-    #         **job_kwargs,
-    #     )
-
-    #     templates, new_sparse_mask = get_templates_from_peaks_and_svd(
-    #         recording,
-    #         peaks,
-    #         peak_labels,
-    #         ms_before,
-    #         ms_after,
-    #         svd_model,
-    #         peaks_svd,
-    #         sparse_mask,
-    #         operator="average",
-    #     )
-
-    #     labels = templates.unit_ids
-
-    #     if verbose:
-    #         print("Kept %d raw clusters" % len(labels))
-
-    #     if params["merge_from_features"] is not None:
-
-    #         merge_features_kwargs = params["merge_from_features"].copy()
-    #         merge_radius_um = merge_features_kwargs.pop("merge_radius_um")
-
-    #         peak_labels, merge_template_array, new_sparse_mask, new_unit_ids = merge_peak_labels_from_features(
-    #             peaks,
-    #             peak_labels,
-    #             templates.unit_ids,
-    #             templates.templates_array,
-    #             sparse_mask,
-    #             recording,
-    #             {"peaks": peaks, "sparse_tsvd": peaks_svd},
-    #             radius_um=merge_radius_um,
-    #             method="project_distribution",
-    #             method_kwargs=dict(
-    #                 feature_name="sparse_tsvd",
-    #                 waveforms_sparse_mask=sparse_mask,
-    #                 **merge_features_kwargs
-    #             ),
-    #             **job_kwargs,
-    #         )
-
-    #         templates = Templates(
-    #             templates_array=merge_template_array,
-    #             sampling_frequency=recording.sampling_frequency,
-    #             nbefore=templates.nbefore,
-    #             sparsity_mask=None,
-    #             channel_ids=recording.channel_ids,
-    #             unit_ids=new_unit_ids,
-    #             probe=recording.get_probe(),
-    #             is_in_uV=False,
-    #         )
-
-    #     if params["merge_from_templates"] is not None:
-    #         peak_labels, merge_template_array, new_sparse_mask, new_unit_ids = merge_peak_labels_from_templates(
-    #             peaks,
-    #             peak_labels,
-    #             templates.unit_ids,
-    #             templates.templates_array,
-    #             new_sparse_mask,
-    #             **params["merge_from_templates"],
-    #         )
-
-    #         templates = Templates(
-    #             templates_array=merge_template_array,
-    #             sampling_frequency=recording.sampling_frequency,
-    #             nbefore=templates.nbefore,
-    #             sparsity_mask=None,
-    #             channel_ids=recording.channel_ids,
-    #             unit_ids=new_unit_ids,
-    #             probe=recording.get_probe(),
-    #             is_in_uV=False,
-    #         )
-
-    #     labels = templates.unit_ids
-
-    #     if debug_folder is not None:
-    #         templates.to_zarr(folder_path=debug_folder / "dense_templates")
-
-    #     if verbose:
-    #         print("Kept %d non-duplicated clusters" % len(labels))
-
-    #     # sparsity = ChannelSparsity(template_sparse_mask, unit_ids, recording.channel_ids)
-    #     # templates = dense_templates.to_sparse(sparsity)
-
-    #     # # sparse_wfs = np.load(features_folder / "sparse_wfs.npy", mmap_mode="r")
-
-    #     # # new_peaks = peaks.copy()
-    #     # # new_peaks["sample_index"] -= peak_shifts
-
-    #     # # clean very small cluster before peeler
-    #     # post_clean_label = post_merge_label2.copy()
-    #     # minimum_cluster_size = params["clean"]["minimum_cluster_size"]
-    #     # labels_set, count = np.unique(post_clean_label, return_counts=True)
-    #     # to_remove = labels_set[count < minimum_cluster_size]
-    #     # mask = np.isin(post_clean_label, to_remove)
-    #     # post_clean_label[mask] = -1
-    #     # final_peak_labels = post_clean_label
-    #     # labels_set = np.unique(final_peak_labels)
-    #     # labels_set = labels_set[labels_set >= 0]
-    #     # templates = templates.select_units(labels_set)
-    #     # labels_set = templates.unit_ids
-
-    #     more_outs = dict(
-    #         templates=templates,
-    #     )
-    #     return labels, peak_labels, more_outs
""" assert operator in ["average", "median"], "operator should be either 'average' or 'median'" @@ -282,6 +287,9 @@ def get_templates_from_peaks_and_svd( num_channels = recording.get_num_channels() templates_array = np.zeros((len(labels), nbefore + nafter, num_channels), dtype=np.float32) + if return_max_std_per_channel: + max_std_per_channel = np.zeros((len(labels), num_channels), dtype=np.float32) + final_sparsity_mask = np.zeros((len(labels), num_channels), dtype="bool") for unit_ind, label in enumerate(labels): mask = valid_labels == label @@ -298,6 +306,11 @@ def get_templates_from_peaks_and_svd( data = np.median(local_svd[sub_mask, :, count], 0) templates_array[unit_ind, :, i] = svd_model.inverse_transform(data.reshape(1, -1)) + if return_max_std_per_channel: + data = svd_model.inverse_transform(local_svd[sub_mask, :, count]) + if len(data) > 1: + max_std_per_channel[unit_ind, i] = np.std(data, 0).max() + dense_templates = Templates( templates_array=templates_array, sampling_frequency=fs, @@ -309,4 +322,37 @@ def get_templates_from_peaks_and_svd( is_in_uV=False, ) - return dense_templates, final_sparsity_mask + if return_max_std_per_channel: + return dense_templates, final_sparsity_mask, max_std_per_channel + else: + return dense_templates, final_sparsity_mask + + +def remove_small_cluster(recording, peaks, peak_labels, min_firing_rate=0.1, subsampling_factor=None, verbose=False): + """ + Remove clusters too small in size (spike count) given a min firing rate and a subsampling factor. + """ + + if subsampling_factor is None: + if verbose: + print("remove_small_cluster(): subsampling_factor is not set, assuming 1") + subsampling_factor = 1 + + min_spike_count = int(recording.get_total_duration() * min_firing_rate / subsampling_factor) + + peak_labels = peak_labels.copy() + labels_set, count = np.unique(peak_labels, return_counts=True) + cluster_mask = count < min_spike_count + to_remove = labels_set[cluster_mask] + to_keep = labels_set[~cluster_mask] + peak_mask = np.isin(peak_labels, to_remove) + peak_labels[peak_mask] = -1 + + to_keep = to_keep[to_keep >= 0] + + if verbose: + print( + f"remove_small_cluster: kept {to_keep.size} removed {to_remove.size} (min_spike_count {min_spike_count})" + ) + + return peak_labels, to_keep diff --git a/src/spikeinterface/sortingcomponents/tools.py b/src/spikeinterface/sortingcomponents/tools.py index 3f4b50be0c..92434e1b3c 100644 --- a/src/spikeinterface/sortingcomponents/tools.py +++ b/src/spikeinterface/sortingcomponents/tools.py @@ -17,9 +17,9 @@ from spikeinterface.core.sparsity import ChannelSparsity from spikeinterface.core.sparsity import compute_sparsity from spikeinterface.core.analyzer_extension_core import ComputeTemplates, ComputeNoiseLevels -from spikeinterface.core.template_tools import get_template_extremum_channel_peak_shift +from spikeinterface.core.template_tools import get_template_extremum_channel_peak_shift, get_template_extremum_channel from spikeinterface.core.recording_tools import get_noise_levels -from spikeinterface.core.sorting_tools import spike_vector_to_indices, get_numba_vector_to_list_of_spiketrain +from spikeinterface.core.sorting_tools import get_numba_vector_to_list_of_spiketrain def make_multi_method_doc(methods, indent=" "): @@ -538,14 +538,25 @@ def get_shuffled_recording_slices(recording, job_kwargs=None, seed=None): def clean_templates( - templates, sparsify_threshold=0.25, noise_levels=None, min_snr=None, max_jitter_ms=None, remove_empty=True + templates, + sparsify_threshold=0.25, + noise_levels=None, + 
diff --git a/src/spikeinterface/sortingcomponents/tools.py b/src/spikeinterface/sortingcomponents/tools.py
index 3f4b50be0c..92434e1b3c 100644
--- a/src/spikeinterface/sortingcomponents/tools.py
+++ b/src/spikeinterface/sortingcomponents/tools.py
@@ -17,9 +17,9 @@
 from spikeinterface.core.sparsity import ChannelSparsity
 from spikeinterface.core.sparsity import compute_sparsity
 from spikeinterface.core.analyzer_extension_core import ComputeTemplates, ComputeNoiseLevels
-from spikeinterface.core.template_tools import get_template_extremum_channel_peak_shift
+from spikeinterface.core.template_tools import get_template_extremum_channel_peak_shift, get_template_extremum_channel
 from spikeinterface.core.recording_tools import get_noise_levels
-from spikeinterface.core.sorting_tools import spike_vector_to_indices, get_numba_vector_to_list_of_spiketrain
+from spikeinterface.core.sorting_tools import get_numba_vector_to_list_of_spiketrain


 def make_multi_method_doc(methods, indent="    "):
@@ -538,14 +538,25 @@ def get_shuffled_recording_slices(recording, job_kwargs=None, seed=None):


 def clean_templates(
-    templates, sparsify_threshold=0.25, noise_levels=None, min_snr=None, max_jitter_ms=None, remove_empty=True
+    templates,
+    sparsify_threshold=0.25,
+    noise_levels=None,
+    min_snr=None,
+    max_jitter_ms=None,
+    remove_empty=True,
+    mean_sd_ratio_threshold=3.0,
+    max_std_per_channel=None,
+    verbose=False,
 ):
     """
     Clean a Templates object by removing empty units and applying sparsity if provided.
     """

+    initial_ids = templates.unit_ids.copy()
+
     ## First we sparsify the templates (using peak-to-peak amplitude avoid sign issues)
     if sparsify_threshold is not None:
+        assert noise_levels is not None, "noise_levels must be provided if sparsify_threshold is set"
         if templates.are_templates_sparse():
             templates = templates.to_dense()
         sparsity = compute_sparsity(
@@ -559,22 +570,30 @@ def clean_templates(

     ## We removed non empty templates
     if remove_empty:
+        n_before = len(templates.unit_ids)
         templates = remove_empty_templates(templates)
+        if verbose:
+            n_after = len(templates.unit_ids)
+            print(f"Removed {n_before - n_after} empty templates")

     ## We keep only units with a max jitter
     if max_jitter_ms is not None:
         max_jitter = int(max_jitter_ms * templates.sampling_frequency / 1000.0)
-
+        n_before = len(templates.unit_ids)
         shifts = get_template_extremum_channel_peak_shift(templates)
         to_select = []
         for unit_id in templates.unit_ids:
             if np.abs(shifts[unit_id]) <= max_jitter:
                 to_select += [unit_id]
         templates = templates.select_units(to_select)
+        if verbose:
+            n_after = len(templates.unit_ids)
+            print(f"Removed {n_before - n_after} unaligned templates")

     ## We remove units with a low SNR
     if min_snr is not None:
         assert noise_levels is not None, "noise_levels must be provided if min_snr is set"
+        n_before = len(templates.unit_ids)
         sparsity = compute_sparsity(
             templates.to_dense(),
             method="snr",
@@ -584,6 +603,25 @@ def clean_templates(
         )
         to_select = templates.unit_ids[np.flatnonzero(sparsity.mask.sum(axis=1) > 0)]
         templates = templates.select_units(to_select)
+        if verbose:
+            n_after = len(templates.unit_ids)
+            print(f"Removed {n_before - n_after} templates with too low SNR")
+
+    ## Lastly, if max_std_per_channel is provided, we remove templates with too high a mean sd / noise ratio
+    if max_std_per_channel is not None:
+        assert noise_levels is not None, "noise_levels must be provided if max_std_per_channel is given"
+        to_select = []
+        n_before = len(templates.unit_ids)
+        for count, unit_id in enumerate(templates.unit_ids):
+            old_index = np.where(unit_id == initial_ids)[0][0]
+            mask = templates.sparsity.mask[count, :]
+            sd_ratio = np.mean(max_std_per_channel[old_index][mask] / noise_levels[mask])
+            if sd_ratio <= mean_sd_ratio_threshold:
+                to_select += [unit_id]
+        templates = templates.select_units(to_select)
+        if verbose:
+            n_after = len(templates.unit_ids)
+            print(f"Removed {n_before - n_after} templates with too high mean sd / noise ratio")

     return templates
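Note (not part of the patch): a sketch of the extended clean_templates() chain with the new mean-sd-ratio step. The variables `templates`, `noise_levels` and `max_std_per_channel` are assumed; `max_std_per_channel` comes from get_templates_from_peaks_and_svd(..., return_max_std_per_channel=True) and must be indexed like the unit_ids the templates started with (the function resolves this via `initial_ids`).

    # Hypothetical usage sketch; values mirror the defaults used by the sorters above.
    from spikeinterface.sortingcomponents.tools import clean_templates

    templates = clean_templates(
        templates,
        sparsify_threshold=1.0,       # now asserts that noise_levels is provided
        noise_levels=noise_levels,
        min_snr=2.5,                  # drop units whose SNR-based sparsity is empty
        max_jitter_ms=0.2,            # drop units whose extremum peak shifts too much
        remove_empty=True,
        mean_sd_ratio_threshold=3.0,  # new: max allowed mean(max std / noise) over sparse channels
        max_std_per_channel=max_std_per_channel,
        verbose=True,                 # new: prints how many units each step removed
    )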