From 1f1c2827c8bb83e5d7b8c74bb8f1f95faa234852 Mon Sep 17 00:00:00 2001 From: chrishalcrow Date: Mon, 3 Nov 2025 17:07:59 +0000 Subject: [PATCH 1/9] initial phy to sa function --- .../extractors/phykilosortextractors.py | 153 +++++++++++++++++- 1 file changed, 152 insertions(+), 1 deletion(-) diff --git a/src/spikeinterface/extractors/phykilosortextractors.py b/src/spikeinterface/extractors/phykilosortextractors.py index 46a8e4cecb..ca3f160382 100644 --- a/src/spikeinterface/extractors/phykilosortextractors.py +++ b/src/spikeinterface/extractors/phykilosortextractors.py @@ -2,12 +2,15 @@ from typing import Optional from pathlib import Path +import json import numpy as np -from spikeinterface.core import BaseSorting, BaseSortingSegment, read_python +from spikeinterface.core import BaseSorting, BaseSortingSegment, read_python, generate_ground_truth_recording, ChannelSparsity, ComputeTemplates, create_sorting_analyzer, SortingAnalyzer from spikeinterface.core.core_tools import define_function_from_class +from spikeinterface.postprocessing import ComputeSpikeAmplitudes, ComputeSpikeLocations +from probeinterface import read_prb class BasePhyKilosortSortingExtractor(BaseSorting): """Base SortingExtractor for Phy and Kilosort output folder. @@ -302,3 +305,151 @@ def __init__(self, folder_path: Path | str, keep_good_only: bool = False, remove read_phy = define_function_from_class(source_class=PhySortingExtractor, name="read_phy") read_kilosort = define_function_from_class(source_class=KiloSortSortingExtractor, name="read_kilosort") + + +def phy_to_analyzer(phy_path, compute_extras=False, unwhiten=True) -> SortingAnalyzer: + """ + Function + """ + + phy_path = Path(phy_path) + + probe = read_prb(phy_path/ "probe.prb") + + sort = read_phy(phy_path) + + sampling_frequency = sort.sampling_frequency + + duration = sort._sorting_segments[0]._all_spikes[-1]/sampling_frequency + 1 + recording, _ = generate_ground_truth_recording(probe=probe.probes[0], sampling_frequency=sampling_frequency, durations=[duration]) + + sparsity = make_sparsity(sort, recording, phy_path) + + sa = create_sorting_analyzer(sort,recording, sparse=True, sparsity=sparsity) + + sa.compute("random_spikes") + + make_templates(sa, phy_path, sparsity.mask, sampling_frequency, unwhiten=unwhiten) + make_locations(sa, phy_path) + + make_amplitudes(sa, phy_path) + + if compute_extras: + sa.compute("unit_locations") + sa.compute("correlograms") + sa.compute("template_similarity") + sa.compute("isi_histograms") + sa.compute("template_metrics", include_multi_channel_metrics=True) + sa.compute("quality_metrics") + + sa._recording = None + + return sa + + +def make_amplitudes(sa, phy_path: Path): + + amplitudes_extension = ComputeSpikeAmplitudes(sa) + + amps_np = np.load(phy_path / "amplitudes.npy") + + amplitudes_extension.data = {} + amplitudes_extension.data["amplitudes"] = amps_np + + params = { + "peak_sign": "neg" + } + amplitudes_extension.params = params + + amplitudes_extension.run_info = {'run_completed': True} + + sa.extensions['spike_amplitudes'] = amplitudes_extension + + +def make_locations(sa, phy_path): + + locations_extension = ComputeSpikeLocations(sa) + + locs_np = np.load(phy_path / "spike_positions.npy") + + num_dims = len(locs_np[0]) + column_names = ["x", "y", "z"][:num_dims] + dtype = [(name, locs_np.dtype) for name in column_names] + + structured_array = np.array(np.zeros(len(locs_np)), dtype=dtype) + for a, column_name in enumerate(column_names): + structured_array[column_name] = locs_np[:,a] + + locations_extension.data 
= {} + locations_extension.data["spike_locations"] = structured_array + + params = {} + locations_extension.params = params + + locations_extension.run_info = {'run_completed': True} + + sa.extensions['spike_locations'] = locations_extension + + +def make_sparsity(sort, rec, phy_path): + + templates = np.load(phy_path / "templates.npy") + + unit_ids = sort.unit_ids + channel_ids = rec.channel_ids + + # The raw templates have dense dimensions (num chan)x(num units) + # but are zero on many channels, which implicitly defines the sparsity + mask = np.sum(np.abs(templates), axis=1) != 0 + return ChannelSparsity(mask, unit_ids=unit_ids, channel_ids=channel_ids) + + +def make_templates(sa, phy_path, mask, sampling_frequency, unwhiten=True): + + template_extension = ComputeTemplates(sa) + + whitened_templates = np.load(phy_path / "templates.npy") + wh_inv = np.load(phy_path / "whitening_mat_inv.npy") + new_templates = compute_unwhitened_templates(whitened_templates, wh_inv, mask) if unwhiten else whitened_templates + + template_extension.data = {'average': new_templates} + + ops_path = phy_path / "ops.npy" + if ops_path.is_file(): + ops = np.load(ops_path, allow_pickle=True) + + samples_before = ops.item(0).get('nt0min') + nt = ops.item(0).get('nt') + + samples_after = nt - samples_before + + ms_before = samples_before/(sampling_frequency//1000) + ms_after = samples_after/(sampling_frequency//1000) + + + params = { + "operators": ["average"], + "ms_before": ms_before, + "ms_after": ms_after, + "peak_sign": "pos", + } + + template_extension.params = params + template_extension.run_info = {'run_completed': True} + + sa.extensions['templates'] = template_extension + + +def compute_unwhitened_templates(whitened_templates, wh_inv, mask): + + template_shape = np.shape(whitened_templates) + new_templates = np.zeros(template_shape) + + sparsity_channel_ids = [np.arange(template_shape[-1])[unit_sparsity] for unit_sparsity in mask] + + for a, unit_sparsity in enumerate(sparsity_channel_ids): + for b in unit_sparsity: + for c in unit_sparsity: + new_templates[a,:,b] += wh_inv[b,c]*whitened_templates[a,:,c] + + return new_templates From 11624289e1dfddded4cddbcb62e0077a25aa9f46 Mon Sep 17 00:00:00 2001 From: chrishalcrow Date: Thu, 6 Nov 2025 16:44:33 +0000 Subject: [PATCH 2/9] add ks to sa converter --- .../extractors/phykilosortextractors.py | 146 +++++++++++------- 1 file changed, 91 insertions(+), 55 deletions(-) diff --git a/src/spikeinterface/extractors/phykilosortextractors.py b/src/spikeinterface/extractors/phykilosortextractors.py index ca3f160382..a8964b8090 100644 --- a/src/spikeinterface/extractors/phykilosortextractors.py +++ b/src/spikeinterface/extractors/phykilosortextractors.py @@ -6,11 +6,21 @@ import numpy as np -from spikeinterface.core import BaseSorting, BaseSortingSegment, read_python, generate_ground_truth_recording, ChannelSparsity, ComputeTemplates, create_sorting_analyzer, SortingAnalyzer +from spikeinterface.core import ( + BaseSorting, + BaseSortingSegment, + read_python, + generate_ground_truth_recording, + ChannelSparsity, + ComputeTemplates, + create_sorting_analyzer, + SortingAnalyzer, +) from spikeinterface.core.core_tools import define_function_from_class from spikeinterface.postprocessing import ComputeSpikeAmplitudes, ComputeSpikeLocations -from probeinterface import read_prb +from probeinterface import read_prb, Probe + class BasePhyKilosortSortingExtractor(BaseSorting): """Base SortingExtractor for Phy and Kilosort output folder. 
@@ -307,47 +317,76 @@ def __init__(self, folder_path: Path | str, keep_good_only: bool = False, remove read_kilosort = define_function_from_class(source_class=KiloSortSortingExtractor, name="read_kilosort") -def phy_to_analyzer(phy_path, compute_extras=False, unwhiten=True) -> SortingAnalyzer: - """ - Function +def kilosort_output_to_analyzer(folder_path, compute_extras=False, unwhiten=True) -> SortingAnalyzer: """ + Load kilosort output into a SortingAnalyzer. + Output from kilosort version 4.1 and above are supported. - phy_path = Path(phy_path) - - probe = read_prb(phy_path/ "probe.prb") - - sort = read_phy(phy_path) - - sampling_frequency = sort.sampling_frequency + Parameters + ---------- + folder_path : str or Path + Path to the output Phy folder (containing the params.py). + compute_extras : bool, default: False + Compute the extra extensions: unit_locations, correlograms, template_similarity, isi_histograms, template_metrics, quality_metrics. + unwhiten : bool, default: True + Unwhiten the templates computed by kilosort. - duration = sort._sorting_segments[0]._all_spikes[-1]/sampling_frequency + 1 - recording, _ = generate_ground_truth_recording(probe=probe.probes[0], sampling_frequency=sampling_frequency, durations=[duration]) + Returns + ------- + sorting_analyzer : SortingAnalyzer + A SortingAnalyzer object. + """ - sparsity = make_sparsity(sort, recording, phy_path) + phy_path = Path(folder_path) + + sorting = read_phy(phy_path) + sampling_frequency = sorting.sampling_frequency + duration = sorting._sorting_segments[0]._all_spikes[-1] / sampling_frequency + 1 + + if (phy_path / "probe.prb").is_file(): + probegroup = read_prb(phy_path / "probe.prb") + probe = probegroup.probes[0] + elif (phy_path / "channel_positions.npy").is_file(): + probe = Probe(si_units="um") + channel_positions = np.load(phy_path / "channel_positions.npy") + probe.set_contacts(channel_positions) + probe.set_device_channel_indices(range(probe.get_contact_count())) + else: + AssertionError("Cannot read probe layout from folder {phy_path}.") + + # to make the initial analyzer, we'll use a fake recording and set it to None later + recording, _ = generate_ground_truth_recording( + probe=probe, sampling_frequency=sampling_frequency, durations=[duration] + ) - sa = create_sorting_analyzer(sort,recording, sparse=True, sparsity=sparsity) + sparsity = _make_sparsity_from_templates(sorting, recording, phy_path) - sa.compute("random_spikes") + sorting_analyzer = create_sorting_analyzer(sorting, recording, sparse=True, sparsity=sparsity) - make_templates(sa, phy_path, sparsity.mask, sampling_frequency, unwhiten=unwhiten) - make_locations(sa, phy_path) + # first compute random spikes. 
These do nothing, but are needed for si-gui to run + sorting_analyzer.compute("random_spikes") - make_amplitudes(sa, phy_path) + _make_templates(sorting_analyzer, phy_path, sparsity.mask, sampling_frequency, unwhiten=unwhiten) + _make_locations(sorting_analyzer, phy_path) + _make_amplitudes(sorting_analyzer, phy_path) if compute_extras: - sa.compute("unit_locations") - sa.compute("correlograms") - sa.compute("template_similarity") - sa.compute("isi_histograms") - sa.compute("template_metrics", include_multi_channel_metrics=True) - sa.compute("quality_metrics") - - sa._recording = None + sorting_analyzer.compute( + { + "unit_locations": {}, + "correlograms": {}, + "template_similarity": {}, + "isi_histograms": {}, + "template_metrics": {"include_multi_channel_metrics": True}, + "quality_metrics": {}, + } + ) - return sa + sorting_analyzer._recording = None + return sorting_analyzer -def make_amplitudes(sa, phy_path: Path): +def _make_amplitudes(sa, phy_path: Path): amplitudes_extension = ComputeSpikeAmplitudes(sa) @@ -356,17 +395,15 @@ def make_amplitudes(sa, phy_path: Path): amplitudes_extension.data = {} amplitudes_extension.data["amplitudes"] = amps_np - params = { - "peak_sign": "neg" - } + params = {"peak_sign": "neg"} amplitudes_extension.params = params - amplitudes_extension.run_info = {'run_completed': True} + amplitudes_extension.run_info = {"run_completed": True} - sa.extensions['spike_amplitudes'] = amplitudes_extension + sa.extensions["spike_amplitudes"] = amplitudes_extension -def make_locations(sa, phy_path): +def _make_locations(sa, phy_path): locations_extension = ComputeSpikeLocations(sa) @@ -378,7 +415,7 @@ def make_locations(sa, phy_path): structured_array = np.array(np.zeros(len(locs_np)), dtype=dtype) for a, column_name in enumerate(column_names): - structured_array[column_name] = locs_np[:,a] + structured_array[column_name] = locs_np[:, a] locations_extension.data = {} locations_extension.data["spike_locations"] = structured_array @@ -386,12 +423,12 @@ def make_locations(sa, phy_path): params = {} locations_extension.params = params - locations_extension.run_info = {'run_completed': True} + locations_extension.run_info = {"run_completed": True} - sa.extensions['spike_locations'] = locations_extension + sa.extensions["spike_locations"] = locations_extension -def make_sparsity(sort, rec, phy_path): +def _make_sparsity_from_templates(sort, rec, phy_path): templates = np.load(phy_path / "templates.npy") @@ -404,43 +441,42 @@ def make_sparsity(sort, rec, phy_path): return ChannelSparsity(mask, unit_ids=unit_ids, channel_ids=channel_ids) -def make_templates(sa, phy_path, mask, sampling_frequency, unwhiten=True): +def _make_templates(sa, phy_path, mask, sampling_frequency, unwhiten=True): template_extension = ComputeTemplates(sa) whitened_templates = np.load(phy_path / "templates.npy") wh_inv = np.load(phy_path / "whitening_mat_inv.npy") - new_templates = compute_unwhitened_templates(whitened_templates, wh_inv, mask) if unwhiten else whitened_templates + new_templates = _compute_unwhitened_templates(whitened_templates, wh_inv, mask) if unwhiten else whitened_templates - template_extension.data = {'average': new_templates} + template_extension.data = {"average": new_templates} ops_path = phy_path / "ops.npy" if ops_path.is_file(): ops = np.load(ops_path, allow_pickle=True) - - samples_before = ops.item(0).get('nt0min') - nt = ops.item(0).get('nt') - samples_after = nt - samples_before + samples_before = ops.item(0).get("nt0min") + nt = ops.item(0).get("nt") - ms_before = 
samples_before/(sampling_frequency//1000)
-        ms_after = samples_after/(sampling_frequency//1000)
+        samples_after = nt - samples_before
 
+        ms_before = samples_before / (sampling_frequency // 1000)
+        ms_after = samples_after / (sampling_frequency // 1000)
 
     params = {
         "operators": ["average"],
         "ms_before": ms_before,
         "ms_after": ms_after,
-        "peak_sign": "pos",
+        "peak_sign": "neg",
     }
 
     template_extension.params = params
-    template_extension.run_info = {'run_completed': True}
+    template_extension.run_info = {"run_completed": True}
+
+    sa.extensions["templates"] = template_extension
 
-    sa.extensions['templates'] = template_extension
-
-def compute_unwhitened_templates(whitened_templates, wh_inv, mask):
+def _compute_unwhitened_templates(whitened_templates, wh_inv, mask):
 
     template_shape = np.shape(whitened_templates)
     new_templates = np.zeros(template_shape)
@@ -450,6 +486,6 @@ def compute_unwhitened_templates(whitened_templates, wh_inv, mask):
     for a, unit_sparsity in enumerate(sparsity_channel_ids):
         for b in unit_sparsity:
             for c in unit_sparsity:
-            new_templates[a,:,b] += wh_inv[b,c]*whitened_templates[a,:,c]
+            new_templates[a, :, b] += wh_inv[b, c] * whitened_templates[a, :, c]
 
     return new_templates

From 8105912645b020e1ba50fb7befed3bd440b150e5 Mon Sep 17 00:00:00 2001
From: Chris Halcrow <57948917+chrishalcrow@users.noreply.github.com>
Date: Mon, 10 Nov 2025 16:03:52 +0000
Subject: [PATCH 3/9] Update src/spikeinterface/extractors/phykilosortextractors.py

Co-authored-by: Joe Ziminski <55797454+JoeZiminski@users.noreply.github.com>
---
 src/spikeinterface/extractors/phykilosortextractors.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/spikeinterface/extractors/phykilosortextractors.py b/src/spikeinterface/extractors/phykilosortextractors.py
index a8964b8090..2739f46597 100644
--- a/src/spikeinterface/extractors/phykilosortextractors.py
+++ b/src/spikeinterface/extractors/phykilosortextractors.py
@@ -352,7 +352,7 @@ def kilosort_output_to_analyzer(folder_path, compute_extras=False, unwhiten=True
         probe.set_contacts(channel_positions)
         probe.set_device_channel_indices(range(probe.get_contact_count()))
     else:
-        AssertionError("Cannot read probe layout from folder {phy_path}.")
+        raise AssertionError(f"Cannot read probe layout from folder {phy_path}.")

From 39aa298919f75171e55f2d0424a1d995f2f44093 Mon Sep 17 00:00:00 2001
From: chrishalcrow
Date: Mon, 10 Nov 2025 17:09:32 +0000
Subject: [PATCH 4/9] easy responses to Joe

---
 .../extractors/phykilosortextractors.py | 92 ++++++++++---------
 1 file changed, 50 insertions(+), 42 deletions(-)

diff --git a/src/spikeinterface/extractors/phykilosortextractors.py b/src/spikeinterface/extractors/phykilosortextractors.py
index 2739f46597..000811b8d1 100644
--- a/src/spikeinterface/extractors/phykilosortextractors.py
+++ b/src/spikeinterface/extractors/phykilosortextractors.py
@@ -345,6 +345,8 @@ def kilosort_output_to_analyzer(folder_path, compute_extras=False, unwhiten=True
 
     if (phy_path / "probe.prb").is_file():
         probegroup = read_prb(phy_path / "probe.prb")
+        if len(probegroup.probes) > 1:
+            raise ValueError("Found more than one probe.
Multiple probes are not currently supported.") probe = probegroup.probes[0] elif (phy_path / "channel_positions.npy").is_file(): probe = Probe(si_units="um") @@ -386,54 +388,52 @@ def kilosort_output_to_analyzer(folder_path, compute_extras=False, unwhiten=True return sorting_analyzer -def _make_amplitudes(sa, phy_path: Path): +def _make_amplitudes(sorting_analyzer, kilosort_output_path): + """Constructs a `spike_amplitudes` extension from the amplitudes numpy array + in `kilosort_output_path`, and attaches the extension to the `sorting_analyzer`.""" - amplitudes_extension = ComputeSpikeAmplitudes(sa) + amplitudes_extension = ComputeSpikeAmplitudes(sorting_analyzer) - amps_np = np.load(phy_path / "amplitudes.npy") - - amplitudes_extension.data = {} - amplitudes_extension.data["amplitudes"] = amps_np - - params = {"peak_sign": "neg"} - amplitudes_extension.params = params + amps_np = np.load(kilosort_output_path / "amplitudes.npy") + amplitudes_extension.data = {"amplitudes": amps_np} + amplitudes_extension.params = {"peak_sign": "neg"} amplitudes_extension.run_info = {"run_completed": True} - sa.extensions["spike_amplitudes"] = amplitudes_extension + sorting_analyzer.extensions["spike_amplitudes"] = amplitudes_extension -def _make_locations(sa, phy_path): +def _make_locations(sorting_analyzer, kilosort_output_path): + """Constructs a `spike_locations` extension from the amplitudes numpy array + in `kilosort_output_path`, and attaches the extension to the `sorting_analyzer`.""" - locations_extension = ComputeSpikeLocations(sa) + locations_extension = ComputeSpikeLocations(sorting_analyzer) - locs_np = np.load(phy_path / "spike_positions.npy") + locs_np = np.load(kilosort_output_path / "spike_positions.npy") num_dims = len(locs_np[0]) column_names = ["x", "y", "z"][:num_dims] dtype = [(name, locs_np.dtype) for name in column_names] - structured_array = np.array(np.zeros(len(locs_np)), dtype=dtype) - for a, column_name in enumerate(column_names): - structured_array[column_name] = locs_np[:, a] - - locations_extension.data = {} - locations_extension.data["spike_locations"] = structured_array - - params = {} - locations_extension.params = params + structured_array = np.zeros(len(locs_np), dtype=dtype) + for coordinate_index, column_name in enumerate(column_names): + structured_array[column_name] = locs_np[:, coordinate_index] + locations_extension.data = {"spike_locations": structured_array} + locations_extension.params = {} locations_extension.run_info = {"run_completed": True} - sa.extensions["spike_locations"] = locations_extension + sorting_analyzer.extensions["spike_locations"] = locations_extension -def _make_sparsity_from_templates(sort, rec, phy_path): +def _make_sparsity_from_templates(sorting, recording, kilosort_output_path): + """Constructs the `ChannelSparsity` of from kilosort output, by seeing if the + templates output is zero or not on all channels.""" - templates = np.load(phy_path / "templates.npy") + templates = np.load(kilosort_output_path / "templates.npy") - unit_ids = sort.unit_ids - channel_ids = rec.channel_ids + unit_ids = sorting.unit_ids + channel_ids = recording.channel_ids # The raw templates have dense dimensions (num chan)x(num units) # but are zero on many channels, which implicitly defines the sparsity @@ -441,27 +441,29 @@ def _make_sparsity_from_templates(sort, rec, phy_path): return ChannelSparsity(mask, unit_ids=unit_ids, channel_ids=channel_ids) -def _make_templates(sa, phy_path, mask, sampling_frequency, unwhiten=True): +def 
_make_templates(sorting_analyzer, kilosort_output_path, mask, sampling_frequency, unwhiten=True): + """Constructs a `templates` extension from the amplitudes numpy array + in `kilosort_output_path`, and attaches the extension to the `sorting_analyzer`.""" - template_extension = ComputeTemplates(sa) + template_extension = ComputeTemplates(sorting_analyzer) - whitened_templates = np.load(phy_path / "templates.npy") - wh_inv = np.load(phy_path / "whitening_mat_inv.npy") + whitened_templates = np.load(kilosort_output_path / "templates.npy") + wh_inv = np.load(kilosort_output_path / "whitening_mat_inv.npy") new_templates = _compute_unwhitened_templates(whitened_templates, wh_inv, mask) if unwhiten else whitened_templates template_extension.data = {"average": new_templates} - ops_path = phy_path / "ops.npy" + ops_path = kilosort_output_path / "ops.npy" if ops_path.is_file(): ops = np.load(ops_path, allow_pickle=True) - samples_before = ops.item(0).get("nt0min") - nt = ops.item(0).get("nt") + number_samples_before_template_peak = ops.item(0)["nt0min"] + total_template_samples = ops.item(0)["nt"] - samples_after = nt - samples_before + number_samples_after_template_peak = total_template_samples - number_samples_before_template_peak - ms_before = samples_before / (sampling_frequency // 1000) - ms_after = samples_after / (sampling_frequency // 1000) + ms_before = number_samples_before_template_peak / (sampling_frequency // 1000) + ms_after = number_samples_after_template_peak / (sampling_frequency // 1000) params = { "operators": ["average"], @@ -473,19 +475,25 @@ def _make_templates(sa, phy_path, mask, sampling_frequency, unwhiten=True): template_extension.params = params template_extension.run_info = {"run_completed": True} - sa.extensions["templates"] = template_extension + sorting_analyzer.extensions["templates"] = template_extension def _compute_unwhitened_templates(whitened_templates, wh_inv, mask): + """Constructs unwhitened templates from whitened_templates, by + applying an inverse whitening matrix.""" template_shape = np.shape(whitened_templates) new_templates = np.zeros(template_shape) sparsity_channel_ids = [np.arange(template_shape[-1])[unit_sparsity] for unit_sparsity in mask] - for a, unit_sparsity in enumerate(sparsity_channel_ids): - for b in unit_sparsity: - for c in unit_sparsity: - new_templates[a, :, b] += wh_inv[b, c] * whitened_templates[a, :, c] + for unit_index, channel_indices in enumerate(sparsity_channel_ids): + for channel_index_1 in channel_indices: + for channel_index_2 in channel_indices: + # templates have dimension unit_index x sample_index x channel_index + # to undo whitening, we need do matrix multiplication on the channel_index + new_templates[unit_index, :, channel_index_1] += ( + wh_inv[channel_index_1, channel_index_2] * whitened_templates[unit_index, :, channel_index_2] + ) return new_templates From 10ed2c4ad3fc89d97289f85d993304b5cc3c69ab Mon Sep 17 00:00:00 2001 From: chrishalcrow Date: Thu, 20 Nov 2025 11:53:50 +0000 Subject: [PATCH 5/9] respond to joe, add guess version number --- .../extractors/phykilosortextractors.py | 42 ++++++++++++------- 1 file changed, 26 insertions(+), 16 deletions(-) diff --git a/src/spikeinterface/extractors/phykilosortextractors.py b/src/spikeinterface/extractors/phykilosortextractors.py index 000811b8d1..3a2713a828 100644 --- a/src/spikeinterface/extractors/phykilosortextractors.py +++ b/src/spikeinterface/extractors/phykilosortextractors.py @@ -2,7 +2,7 @@ from typing import Optional from pathlib import Path -import json 
+import warnings import numpy as np @@ -326,8 +326,6 @@ def kilosort_output_to_analyzer(folder_path, compute_extras=False, unwhiten=True ---------- folder_path : str or Path Path to the output Phy folder (containing the params.py). - compute_extras : bool, default: False - Compute the extra extensions: unit_locations, correlograms, template_similarity, isi_histograms, template_metrics, quality_metrics. unwhiten : bool, default: True Unwhiten the templates computed by kilosort. @@ -339,14 +337,19 @@ def kilosort_output_to_analyzer(folder_path, compute_extras=False, unwhiten=True phy_path = Path(folder_path) + guessed_kilosort_version = _guess_kilosort_version(phy_path) + sorting = read_phy(phy_path) sampling_frequency = sorting.sampling_frequency + + # kilosort occasionally contains a few spikes beyond the recording end point, which can lead + # to errors later. To avoid this, we pad the recording with an extra second of blank time. duration = sorting._sorting_segments[0]._all_spikes[-1] / sampling_frequency + 1 if (phy_path / "probe.prb").is_file(): probegroup = read_prb(phy_path / "probe.prb") if len(probegroup.probes) > 0: - raise ValueError("Found more than one probe. Multiple probes are not currently supported.") + warnings.warn("Found more than one probe. Selecting the first probe in ProbeGroup.") probe = probegroup.probes[0] elif (phy_path / "channel_positions.npy").is_file(): probe = Probe(si_units="um") @@ -372,22 +375,29 @@ def kilosort_output_to_analyzer(folder_path, compute_extras=False, unwhiten=True _make_locations(sorting_analyzer, phy_path) _make_amplitudes(sorting_analyzer, phy_path) - if compute_extras: - sorting_analyzer.compute( - { - "unit_locations": {}, - "correlograms": {}, - "template_similarity": {}, - "isi_histograms": {}, - "template_metrics": {"include_multi_channel_metrics": True}, - "quality_metrics": {}, - } - ) - sorting_analyzer._recording = None return sorting_analyzer +def _guess_kilosort_version(kilosort_path) -> tuple: + """ + Guesses the kilosort version based on the files which exist in folder `kilosort_path`. + If unknown, returns minimum guessed version. + + Returns + ------- + version_number : tuple + Version number in the form (major, minor, patch) + """ + + kilosort_log_file = Path(kilosort_path / "kilosort4.log") + + if kilosort_log_file.is_file(): + return (4, 0, 33) + else: + return (2, 0, 0) + + def _make_amplitudes(sorting_analyzer, kilosort_output_path): """Constructs a `spike_amplitudes` extension from the amplitudes numpy array in `kilosort_output_path`, and attaches the extension to the `sorting_analyzer`.""" From 8fe91b71e9242d9202ddd8babfab6bd17d661b96 Mon Sep 17 00:00:00 2001 From: chrishalcrow Date: Thu, 20 Nov 2025 13:57:06 +0000 Subject: [PATCH 6/9] add docs --- doc/how_to/import_kilosort_data.rst | 49 +++++++++++++++++++++++++++++ doc/how_to/index.rst | 1 + 2 files changed, 50 insertions(+) create mode 100644 doc/how_to/import_kilosort_data.rst diff --git a/doc/how_to/import_kilosort_data.rst b/doc/how_to/import_kilosort_data.rst new file mode 100644 index 0000000000..1143f71958 --- /dev/null +++ b/doc/how_to/import_kilosort_data.rst @@ -0,0 +1,49 @@ +Import Kilosort4 output +======================= + +If you have sorted your data with `Kilosort4 `__, your sorter output is saved in format which was +designed to be compatible with `phy `__. SpikeInterface provides a function which can be used to +transform this output into a ``SortingAnalyzer``. This is helpful if you'd like to compute some more properties of your sorting +(e.g. 
quality and template metrics), or if you'd like to visualize your output using `spikeinterface-gui `__.
+
+To create an analyzer from a Kilosort4 output folder, simply run
+
+.. code::
+
+    from spikeinterface.extractors import kilosort_output_to_analyzer
+    sorting_analyzer = kilosort_output_to_analyzer('path/to/output')
+
+The ``'path/to/output'`` should point to the Kilosort4 output folder. If you ran Kilosort4 natively, this is wherever you asked Kilosort4 to
+save your output. If you ran Kilosort4 using SpikeInterface, this is in the ``sorter_output`` folder inside the ``output_folder`` created
+when you ran ``run_sorter``.
+
+The ``analyzer`` object contains as much information as it can grab from the Kilosort4 output. If everything works, it should contain
+information about the ``templates``, ``spike_locations`` and ``spike_amplitudes``. These are stored as ``extensions`` of the ``SortingAnalyzer``.
+You can compute extra information about the sorting using the ``compute`` method. For example,
+
+.. code::
+
+    sorting_analyzer.compute({
+        "unit_locations": {},
+        "correlograms": {},
+        "template_similarity": {},
+        "isi_histograms": {},
+        "template_metrics": {include_multi_channel_metrics: True},
+        "quality_metrics": {},
+    })
+
+Many of these extensions can be plotted: see the `available plotting functions <https://spikeinterface.readthedocs.io/en/stable/modules/widgets.html#available-plotting-functions>`__.
+
+Learn more about the ``SortingAnalyzer`` and its ``extensions`` `here `__.
+
+If you'd like to store the information you've computed, you can save the analyzer:
+
+.. code::
+
+    sorting_analyzer.save_as(
+        format="binary_folder",
+        folder="my_kilosort_analyzer"
+    )
+
+You now have a fully functional ``SortingAnalyzer`` - congrats! You can now use `spikeinterface-gui `__. to view the results
+interactively, or start manually labelling your units to `create an automated curation model `__.
diff --git a/doc/how_to/index.rst b/doc/how_to/index.rst
index 2d490c207d..db02da5cef 100644
--- a/doc/how_to/index.rst
+++ b/doc/how_to/index.rst
@@ -22,3 +22,4 @@ Guides on how to solve specific, short problems in SpikeInterface. Learn how to.
benchmark_with_hybrid_recordings auto_curation_training auto_curation_prediction + import_kilosort_data From 06598915fd99f83484022ca8a58ff6085bf0d6c7 Mon Sep 17 00:00:00 2001 From: chrishalcrow Date: Thu, 4 Dec 2025 16:03:25 +0000 Subject: [PATCH 7/9] respond to sam and add stuff for ks2,3 --- src/spikeinterface/extractors/__init__.py | 3 +- .../extractors/phykilosortextractors.py | 67 +++++++++---------- 2 files changed, 31 insertions(+), 39 deletions(-) diff --git a/src/spikeinterface/extractors/__init__.py b/src/spikeinterface/extractors/__init__.py index 604dee68f9..216c668e0e 100644 --- a/src/spikeinterface/extractors/__init__.py +++ b/src/spikeinterface/extractors/__init__.py @@ -3,10 +3,9 @@ from .toy_example import toy_example as toy_example from .bids import read_bids as read_bids - from .neuropixels_utils import get_neuropixels_channel_groups, get_neuropixels_sample_shifts - from .neoextractors import get_neo_num_blocks, get_neo_streams +from .phykilosortextractors import read_kilosort_as_analyzer from warnings import warn diff --git a/src/spikeinterface/extractors/phykilosortextractors.py b/src/spikeinterface/extractors/phykilosortextractors.py index 3a2713a828..f0741a2164 100644 --- a/src/spikeinterface/extractors/phykilosortextractors.py +++ b/src/spikeinterface/extractors/phykilosortextractors.py @@ -317,7 +317,7 @@ def __init__(self, folder_path: Path | str, keep_good_only: bool = False, remove read_kilosort = define_function_from_class(source_class=KiloSortSortingExtractor, name="read_kilosort") -def kilosort_output_to_analyzer(folder_path, compute_extras=False, unwhiten=True) -> SortingAnalyzer: +def read_kilosort_as_analyzer(folder_path, unwhiten=True) -> SortingAnalyzer: """ Load kilosort output into a SortingAnalyzer. Output from kilosort version 4.1 and above are supported. 
@@ -361,7 +361,11 @@ def kilosort_output_to_analyzer(folder_path, compute_extras=False, unwhiten=True # to make the initial analyzer, we'll use a fake recording and set it to None later recording, _ = generate_ground_truth_recording( - probe=probe, sampling_frequency=sampling_frequency, durations=[duration] + probe=probe, + sampling_frequency=sampling_frequency, + durations=[duration], + num_units=1, + seed=1205, ) sparsity = _make_sparsity_from_templates(sorting, recording, phy_path) @@ -373,7 +377,6 @@ def kilosort_output_to_analyzer(folder_path, compute_extras=False, unwhiten=True _make_templates(sorting_analyzer, phy_path, sparsity.mask, sampling_frequency, unwhiten=unwhiten) _make_locations(sorting_analyzer, phy_path) - _make_amplitudes(sorting_analyzer, phy_path) sorting_analyzer._recording = None return sorting_analyzer @@ -398,28 +401,17 @@ def _guess_kilosort_version(kilosort_path) -> tuple: return (2, 0, 0) -def _make_amplitudes(sorting_analyzer, kilosort_output_path): - """Constructs a `spike_amplitudes` extension from the amplitudes numpy array - in `kilosort_output_path`, and attaches the extension to the `sorting_analyzer`.""" - - amplitudes_extension = ComputeSpikeAmplitudes(sorting_analyzer) - - amps_np = np.load(kilosort_output_path / "amplitudes.npy") - - amplitudes_extension.data = {"amplitudes": amps_np} - amplitudes_extension.params = {"peak_sign": "neg"} - amplitudes_extension.run_info = {"run_completed": True} - - sorting_analyzer.extensions["spike_amplitudes"] = amplitudes_extension - - def _make_locations(sorting_analyzer, kilosort_output_path): """Constructs a `spike_locations` extension from the amplitudes numpy array in `kilosort_output_path`, and attaches the extension to the `sorting_analyzer`.""" locations_extension = ComputeSpikeLocations(sorting_analyzer) - locs_np = np.load(kilosort_output_path / "spike_positions.npy") + spike_locations_path = kilosort_output_path / "spike_positions.npy" + if spike_locations_path.is_file(): + locs_np = np.load(spike_locations_path) + else: + return num_dims = len(locs_np[0]) column_names = ["x", "y", "z"][:num_dims] @@ -445,7 +437,7 @@ def _make_sparsity_from_templates(sorting, recording, kilosort_output_path): unit_ids = sorting.unit_ids channel_ids = recording.channel_ids - # The raw templates have dense dimensions (num chan)x(num units) + # The raw templates have dense dimensions (num chan)x(num samples)x(num units) # but are zero on many channels, which implicitly defines the sparsity mask = np.sum(np.abs(templates), axis=1) != 0 return ChannelSparsity(mask, unit_ids=unit_ids, channel_ids=channel_ids) @@ -459,7 +451,7 @@ def _make_templates(sorting_analyzer, kilosort_output_path, mask, sampling_frequ whitened_templates = np.load(kilosort_output_path / "templates.npy") wh_inv = np.load(kilosort_output_path / "whitening_mat_inv.npy") - new_templates = _compute_unwhitened_templates(whitened_templates, wh_inv, mask) if unwhiten else whitened_templates + new_templates = _compute_unwhitened_templates(whitened_templates, wh_inv) if unwhiten else whitened_templates template_extension.data = {"average": new_templates} @@ -475,11 +467,21 @@ def _make_templates(sorting_analyzer, kilosort_output_path, mask, sampling_frequ ms_before = number_samples_before_template_peak / (sampling_frequency // 1000) ms_after = number_samples_after_template_peak / (sampling_frequency // 1000) + # Used for kilosort 2, 2.5 and 3 + else: + + warnings.warn("Can't extract `ms_before` and `ms_after` from Kilosort output. 
Guessing a sensible value.")
+
+        samples_in_templates = np.shape(new_templates)[1]
+        template_extent_ms = (samples_in_templates + 1) / (sampling_frequency // 1000)
+        ms_before = template_extent_ms / 3
+        ms_after = 2 * template_extent_ms / 3
+
     params = {
         "operators": ["average"],
         "ms_before": ms_before,
         "ms_after": ms_after,
-        "peak_sign": "neg",
+        "peak_sign": "both",
     }
 
     template_extension.params = params
@@ -488,22 +490,13 @@ def _make_templates(sorting_analyzer, kilosort_output_path, mask, sampling_frequ
     sorting_analyzer.extensions["templates"] = template_extension
 
 
-def _compute_unwhitened_templates(whitened_templates, wh_inv, mask):
+def _compute_unwhitened_templates(whitened_templates, wh_inv):
     """Constructs unwhitened templates from whitened_templates, by
     applying an inverse whitening matrix."""
 
-    template_shape = np.shape(whitened_templates)
-    new_templates = np.zeros(template_shape)
-
-    sparsity_channel_ids = [np.arange(template_shape[-1])[unit_sparsity] for unit_sparsity in mask]
-
-    for unit_index, channel_indices in enumerate(sparsity_channel_ids):
-        for channel_index_1 in channel_indices:
-            for channel_index_2 in channel_indices:
-                # templates have dimension unit_index x sample_index x channel_index
-                # to undo whitening, we need do matrix multiplication on the channel_index
-                new_templates[unit_index, :, channel_index_1] += (
-                    wh_inv[channel_index_1, channel_index_2] * whitened_templates[unit_index, :, channel_index_2]
-                )
+    # templates have dimension (num units) x (num samples) x (num channels)
+    # the inverse whitening matrix has dimension (num channels) x (num channels)
+    # to undo whitening, we need to do matrix multiplication over the channel index
+    unwhitened_templates = np.einsum("ij,klj->kli", wh_inv, whitened_templates)
 
-    return new_templates
+    return unwhitened_templates

From 0f1a2c578f1d4ead1d957005180a9fe7c349711c Mon Sep 17 00:00:00 2001
From: chrishalcrow
Date: Thu, 4 Dec 2025 16:20:34 +0000
Subject: [PATCH 8/9] try adding a test to ks4 tests

---
 .github/scripts/test_kilosort4_ci.py | 4 ++++
 1 file changed, 4 insertions(+)

diff --git a/.github/scripts/test_kilosort4_ci.py b/.github/scripts/test_kilosort4_ci.py
index 47bbd1f4d1..1a6a2edc32 100644
--- a/.github/scripts/test_kilosort4_ci.py
+++ b/.github/scripts/test_kilosort4_ci.py
@@ -30,6 +30,7 @@
 from spikeinterface.core.testing import check_sortings_equal
 from spikeinterface.sorters.external.kilosort4 import Kilosort4Sorter
 from probeinterface.io import write_prb
+from spikeinterface.extractors import read_kilosort_as_analyzer
 
 import kilosort
 from kilosort.parameters import DEFAULT_SETTINGS
@@ -396,6 +397,9 @@ def test_kilosort4_main(self, recording_and_paths, default_kilosort_sorting, tmp
         with pytest.raises(AssertionError):
             check_sortings_equal(default_kilosort_sorting, sorting_si)
 
+        # Check that the kilosort -> analyzer tool works
+        analyzer = read_kilosort_as_analyzer(kilosort_output_dir)
+
     def test_clear_cache(self,recording_and_paths, tmp_path):
         """
         Test clear_cache parameter in kilosort4.run_kilosort

From 369f3eeba8d1f3919df3fdcaee7ac8c693cffe6e Mon Sep 17 00:00:00 2001
From: chrishalcrow
Date: Tue, 9 Dec 2025 09:51:44 +0000
Subject: [PATCH 9/9] check spike length, update docs, remove version checker

---
 .github/scripts/test_kilosort4_ci.py |  6 +--
 doc/how_to/import_kilosort_data.rst | 38 +++++++++++++++++--
 .../extractors/phykilosortextractors.py | 37 +++++++-----------
 3 files changed, 51 insertions(+), 30 deletions(-)

diff --git a/.github/scripts/test_kilosort4_ci.py b/.github/scripts/test_kilosort4_ci.py
index
1a6a2edc32..689f48b39e 100644 --- a/.github/scripts/test_kilosort4_ci.py +++ b/.github/scripts/test_kilosort4_ci.py @@ -391,13 +391,13 @@ def test_kilosort4_main(self, recording_and_paths, default_kilosort_sorting, tmp ops = ops.tolist() # strangely this makes a dict assert ops[param_key] == param_value - # Finally, check out test parameters actually change the output of - # KS4, ensuring our tests are actually doing something (exxcept for some params). + # Check our test parameters actually change the output of + # KS4, ensuring our tests are actually doing something (except for some params). if param_key not in PARAMETERS_NOT_AFFECTING_RESULTS: with pytest.raises(AssertionError): check_sortings_equal(default_kilosort_sorting, sorting_si) - # Check that the kilosort -> analyzer tool works + # Check that the kilosort -> analyzer tool doesn't error analyzer = read_kilosort_as_analyzer(kilosort_output_dir) def test_clear_cache(self,recording_and_paths, tmp_path): diff --git a/doc/how_to/import_kilosort_data.rst b/doc/how_to/import_kilosort_data.rst index 1143f71958..dad522334a 100644 --- a/doc/how_to/import_kilosort_data.rst +++ b/doc/how_to/import_kilosort_data.rst @@ -10,13 +10,16 @@ To create an analyzer from a Kilosort4 output folder, simply run .. code:: - from spikeinterface.extractors import kilosort_output_to_analyzer - sorting_analyzer = kilosort_output_to_analyzer('path/to/output') + from spikeinterface.extractors import read_kilosort_as_analyzer + sorting_analyzer = read_kilosort_as_analyzer('path/to/output') The ``'path/to/output'`` should point to the Kilosort4 output folder. If you ran Kilosort4 natively, this is wherever you asked Kilosort4 to save your output. If you ran Kilosort4 using SpikeInterface, this is in the ``sorter_output`` folder inside the ``output_folder`` created when you ran ``run_sorter``. +Note: the function ``read_kilosort_as_analyzer`` might work on older versions of Kilosort such as Kilosort2 and Kilosort3. +However, we do not guarantee that the results are correct. + The ``analyzer`` object contains as much information as it can grab from the Kilosort4 output. If everything works, it should contain information about the ``templates``, ``spike_locations`` and ``spike_amplitudes``. These are stored as ``extensions`` of the ``SortingAnalyzer``. You can compute extra information about the sorting using the ``compute`` method. For example, @@ -28,7 +31,7 @@ You can compute extra information about the sorting using the ``compute`` method "correlograms": {}, "template_similarity": {}, "isi_histograms": {}, - "template_metrics": {include_multi_channel_metrics: True}, + "template_metrics": {"include_multi_channel_metrics": True}, "quality_metrics": {}, }) @@ -47,3 +50,32 @@ If you'd like to store the information you've computed, you can save the analyze You now have a fully functional ``SortingAnalyzer`` - congrats! You can now use `spikeinterface-gui `__. to view the results interactively, or start manually labelling your units to `create an automated curation model `__. + +Note that if you have access to the raw recording, you can attach it to the analyzer, and re-compute extensions from the raw data. E.g. + +.. 
code::
+
+    from spikeinterface.extractors import read_kilosort_as_analyzer
+    import spikeinterface.extractors as se
+    import spikeinterface.preprocessing as spre
+
+    recording = se.read_openephys('path/to/recording')
+
+    preprocessed_recording = spre.bandpass_filter(spre.common_reference(recording))
+
+    sorting_analyzer = read_kilosort_as_analyzer('path/to/output')
+    sorting_analyzer.set_temporary_recording(preprocessed_recording)
+
+    sorting_analyzer.compute({
+        "spike_locations": {},
+        "spike_amplitudes": {},
+        "unit_locations": {},
+        "correlograms": {},
+        "template_similarity": {},
+        "isi_histograms": {},
+        "template_metrics": {"include_multi_channel_metrics": True},
+        "quality_metrics": {},
+    })
+
+
+This will take longer since you are dealing with the raw recording, but you do have a lot of control over how to compute the extensions.
diff --git a/src/spikeinterface/extractors/phykilosortextractors.py b/src/spikeinterface/extractors/phykilosortextractors.py
index f0741a2164..68b16074fb 100644
--- a/src/spikeinterface/extractors/phykilosortextractors.py
+++ b/src/spikeinterface/extractors/phykilosortextractors.py
@@ -319,8 +319,9 @@ def __init__(self, folder_path: Path | str, keep_good_only: bool = False, remove
 def read_kilosort_as_analyzer(folder_path, unwhiten=True) -> SortingAnalyzer:
     """
-    Load kilosort output into a SortingAnalyzer.
-    Output from kilosort version 4.1 and above are supported.
+    Load Kilosort output into a SortingAnalyzer. Output from Kilosort version 4.1 and
+    above is supported. The function may work on older versions of Kilosort output,
+    but these are not thoroughly tested. Please check your output carefully.
 
     Parameters
     ----------
@@ -337,12 +338,10 @@ def read_kilosort_as_analyzer(folder_path, unwhiten=True) -> SortingAnalyzer:
 
     phy_path = Path(folder_path)
 
-    guessed_kilosort_version = _guess_kilosort_version(phy_path)
-
     sorting = read_phy(phy_path)
     sampling_frequency = sorting.sampling_frequency
 
-    # kilosort occasionally contains a few spikes beyond the recording end point, which can lead
+    # kilosort occasionally contains a few spikes just beyond the recording end point, which can lead
     # to errors later. To avoid this, we pad the recording with an extra second of blank time.
     duration = sorting._sorting_segments[0]._all_spikes[-1] / sampling_frequency + 1
 
@@ -382,25 +381,6 @@ def read_kilosort_as_analyzer(folder_path, unwhiten=True) -> SortingAnalyzer:
     return sorting_analyzer
 
 
-def _guess_kilosort_version(kilosort_path) -> tuple:
-    """
-    Guesses the kilosort version based on the files which exist in folder `kilosort_path`.
-    If unknown, returns minimum guessed version.
-
-    Returns
-    -------
-    version_number : tuple
-        Version number in the form (major, minor, patch)
-    """
-
-    kilosort_log_file = Path(kilosort_path / "kilosort4.log")
-
-    if kilosort_log_file.is_file():
-        return (4, 0, 33)
-    else:
-        return (2, 0, 0)
-
-
 def _make_locations(sorting_analyzer, kilosort_output_path):
     """Constructs a `spike_locations` extension from the amplitudes numpy array
     in `kilosort_output_path`, and attaches the extension to the `sorting_analyzer`."""
@@ -413,6 +393,15 @@ def _make_locations(sorting_analyzer, kilosort_output_path):
     else:
         return
 
+    # Check that the spike locations vector is the same size as the spike vector
+    num_spikes = len(sorting_analyzer.sorting.to_spike_vector())
+    num_spike_locs = len(locs_np)
+    if num_spikes != num_spike_locs:
+        warnings.warn(
+            "The number of spikes does not match the number of spike locations in `spike_positions.npy`.
Skipping spike locations." + ) + return + num_dims = len(locs_np[0]) column_names = ["x", "y", "z"][:num_dims] dtype = [(name, locs_np.dtype) for name in column_names]
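
A note on the unwhitening rewrite in PATCH 7/9: the `np.einsum("ij,klj->kli", wh_inv, whitened_templates)` call computes out[k, l, i] = sum_j wh_inv[i, j] * whitened_templates[k, l, j], which is the same per-channel accumulation the triple loop from PATCH 4/9 performed. The snippet below is a standalone sanity check of that equivalence for the dense case (all channels active). It is not part of the patch series, and the array sizes are invented purely for illustration:

    import numpy as np

    # hypothetical sizes, chosen only for this check
    num_units, num_samples, num_channels = 3, 61, 8

    rng = np.random.default_rng(0)
    whitened_templates = rng.normal(size=(num_units, num_samples, num_channels))
    wh_inv = rng.normal(size=(num_channels, num_channels))

    # einsum version (PATCH 7): out[k, l, i] = sum_j wh_inv[i, j] * whitened_templates[k, l, j]
    fast = np.einsum("ij,klj->kli", wh_inv, whitened_templates)

    # loop version (PATCH 4), written for a dense mask where every channel is active
    slow = np.zeros_like(whitened_templates)
    for unit_index in range(num_units):
        for chan_out in range(num_channels):
            for chan_in in range(num_channels):
                slow[unit_index, :, chan_out] += wh_inv[chan_out, chan_in] * whitened_templates[unit_index, :, chan_in]

    assert np.allclose(fast, slow)

For sparse units the two versions should agree on every channel inside a unit's sparsity mask, since the whitened templates are zero off-mask; the dense einsum can write nonzero values to off-mask channels, but those are discarded by the analyzer's sparsity.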