8 changes: 6 additions & 2 deletions .github/scripts/test_kilosort4_ci.py
@@ -30,6 +30,7 @@
from spikeinterface.core.testing import check_sortings_equal
from spikeinterface.sorters.external.kilosort4 import Kilosort4Sorter
from probeinterface.io import write_prb
from spikeinterface.extractors import read_kilosort_as_analyzer

import kilosort
from kilosort.parameters import DEFAULT_SETTINGS
@@ -405,12 +406,15 @@ def test_kilosort4_main(self, recording_and_paths, default_kilosort_sorting, tmp
ops = ops.tolist() # strangely this makes a dict
assert ops[param_key] == param_value

# Finally, check out test parameters actually change the output of
# KS4, ensuring our tests are actually doing something (exxcept for some params).
# Check our test parameters actually change the output of
# KS4, ensuring our tests are actually doing something (except for some params).
if param_key not in PARAMETERS_NOT_AFFECTING_RESULTS:
with pytest.raises(AssertionError):
check_sortings_equal(default_kilosort_sorting, sorting_si)

# Check that the kilosort -> analyzer tool doesn't error
analyzer = read_kilosort_as_analyzer(kilosort_output_dir)

def test_clear_cache(self, recording_and_paths, tmp_path):
"""
Test clear_cache parameter in kilosort4.run_kilosort
81 changes: 81 additions & 0 deletions doc/how_to/import_kilosort_data.rst
@@ -0,0 +1,81 @@
Import Kilosort4 output
=======================

If you have sorted your data with `Kilosort4 <https://github.com/MouseLand/Kilosort>`__, your sorter output is saved in a format
designed to be compatible with `phy <https://github.com/cortex-lab/phy>`__. SpikeInterface provides a function which
transforms this output into a ``SortingAnalyzer``. This is helpful if you'd like to compute more properties of your sorting
(e.g. quality and template metrics), or if you'd like to visualize your output using `spikeinterface-gui <https://github.com/SpikeInterface/spikeinterface-gui/>`__.

To create an analyzer from a Kilosort4 output folder, simply run

.. code::

from spikeinterface.extractors import read_kilosort_as_analyzer
sorting_analyzer = read_kilosort_as_analyzer('path/to/output')

The ``'path/to/output'`` should point to the Kilosort4 output folder. If you ran Kilosort4 natively, this is wherever you asked Kilosort4 to
save your output. If you ran Kilosort4 through SpikeInterface, this is the ``sorter_output`` folder inside the output folder created
when you ran ``run_sorter``.
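
For example, if you sorted through SpikeInterface, a minimal sketch of the whole route might look like this (the folder name ``ks4_output`` is hypothetical, and ``recording`` is assumed to be a previously loaded recording):

.. code::

    from spikeinterface.sorters import run_sorter
    from spikeinterface.extractors import read_kilosort_as_analyzer

    # run Kilosort4 through SpikeInterface; `recording` is a previously loaded Recording
    sorting = run_sorter("kilosort4", recording, folder="ks4_output")

    # Kilosort4's own files are written to the "sorter_output" subfolder
    sorting_analyzer = read_kilosort_as_analyzer("ks4_output/sorter_output")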

Note: the function ``read_kilosort_as_analyzer`` might work on output from older versions of Kilosort, such as Kilosort2 and Kilosort3.
However, we do not guarantee that the results are correct.

The ``sorting_analyzer`` object contains as much information as can be read from the Kilosort4 output. If everything works, it should contain
information about the ``templates``, ``spike_locations`` and ``spike_amplitudes``. These are stored as ``extensions`` of the ``SortingAnalyzer``.
You can compute extra information about the sorting using the ``compute`` method. For example:

.. code::

sorting_analyzer.compute({
"unit_locations": {},
"correlograms": {},
"template_similarity": {},
"isi_histograms": {},
"template_metrics": {"include_multi_channel_metrics": True},
"quality_metrics": {},
})
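
Once an extension has been computed, you can retrieve its data with ``get_extension``. A quick sketch:

.. code::

    # quality metrics are returned as a pandas DataFrame
    quality_metrics = sorting_analyzer.get_extension("quality_metrics").get_data()
    print(quality_metrics.head())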

Many of these extensions can be visualized with SpikeInterface's `available plotting functions <https://spikeinterface.readthedocs.io/en/stable/modules/widgets.html#available-plotting-functions>`__.

Learn more about the ``SortingAnalyzer`` and its ``extensions`` `here <https://spikeinterface.readthedocs.io/en/stable/modules/postprocessing.html>`__.

If you'd like to store the information you've computed, you can save the analyzer:

.. code::

sorting_analyzer.save_as(
format="binary_folder",
folder="my_kilosort_analyzer"
)
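
You can reload the saved analyzer in a later session; a short sketch, assuming the folder from above:

.. code::

    from spikeinterface import load_sorting_analyzer
    sorting_analyzer = load_sorting_analyzer("my_kilosort_analyzer")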

You now have a fully functional ``SortingAnalyzer`` - congrats! You can use `spikeinterface-gui <https://github.com/SpikeInterface/spikeinterface-gui/>`__ to view the results
interactively, or start manually labelling your units to `create an automated curation model <https://spikeinterface.readthedocs.io/en/stable/tutorials_custom_index.html#automated-curation-tutorials>`__.
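
One way to launch the GUI (a sketch, assuming ``spikeinterface-gui`` is installed) is through the sorting-summary widget:

.. code::

    from spikeinterface.widgets import plot_sorting_summary

    # opens the analyzer in the spikeinterface-gui desktop application
    plot_sorting_summary(sorting_analyzer, backend="spikeinterface_gui")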

Note that if you have access to the raw recording, you can attach it to the analyzer and re-compute extensions from the raw data. For example:

.. code::

from spikeinterface.extractors import read_kilosort_as_analyzer
import spikeinterface.extractors as se
import spikeinterface.preprocessing as spre

recording = se.read_openephys('path/to/recording')

preprocessed_recording = spre.bandpass_filter(spre.common_reference(recording))

sorting_analyzer = read_kilosort_as_analyzer('path/to/output')
sorting_analyzer.set_temporary_recording(preprocessed_recording)

sorting_analyzer.compute({
"spike_locations": {},
"spike_amplitudes": {},
"unit_locations": {},
"correlograms": {},
"template_similarity": {},
"isi_histograms": {},
"template_metrics": {"include_multi_channel_metrics": True},
"quality_metrics": {},
})


This will take longer since you are working with the raw recording, but it gives you much more control over how the extensions are computed.
1 change: 1 addition & 0 deletions doc/how_to/index.rst
@@ -22,3 +22,4 @@ Guides on how to solve specific, short problems in SpikeInterface. Learn how to.
benchmark_with_hybrid_recordings
auto_curation_training
auto_curation_prediction
import_kilosort_data
3 changes: 1 addition & 2 deletions src/spikeinterface/extractors/__init__.py
@@ -3,10 +3,9 @@
from .toy_example import toy_example as toy_example
from .bids import read_bids as read_bids


from .neuropixels_utils import get_neuropixels_channel_groups, get_neuropixels_sample_shifts

from .neoextractors import get_neo_num_blocks, get_neo_streams
from .phykilosortextractors import read_kilosort_as_analyzer

from warnings import warn

189 changes: 188 additions & 1 deletion src/spikeinterface/extractors/phykilosortextractors.py
@@ -2,12 +2,25 @@

from typing import Optional
from pathlib import Path
import warnings

import numpy as np

from spikeinterface.core import BaseSorting, BaseSortingSegment, read_python
from spikeinterface.core import (
BaseSorting,
BaseSortingSegment,
read_python,
generate_ground_truth_recording,
ChannelSparsity,
ComputeTemplates,
create_sorting_analyzer,
SortingAnalyzer,
)
from spikeinterface.core.core_tools import define_function_from_class

from spikeinterface.postprocessing import ComputeSpikeAmplitudes, ComputeSpikeLocations
from probeinterface import read_prb, Probe


class BasePhyKilosortSortingExtractor(BaseSorting):
"""Base SortingExtractor for Phy and Kilosort output folder.
@@ -302,3 +315,177 @@ def __init__(self, folder_path: Path | str, keep_good_only: bool = False, remove

read_phy = define_function_from_class(source_class=PhySortingExtractor, name="read_phy")
read_kilosort = define_function_from_class(source_class=KiloSortSortingExtractor, name="read_kilosort")


def read_kilosort_as_analyzer(folder_path, unwhiten=True) -> SortingAnalyzer:
"""
Load Kilosort output into a SortingAnalyzer. Output from Kilosort version 4.1 and
above is supported. The function may work on output from older versions of Kilosort,
but these are not carefully tested. Please check your output carefully.

Parameters
----------
folder_path : str or Path
Path to the output Phy folder (containing the params.py).
unwhiten : bool, default: True
Unwhiten the templates computed by kilosort.

Returns
-------
sorting_analyzer : SortingAnalyzer
A SortingAnalyzer object.
"""

phy_path = Path(folder_path)

sorting = read_phy(phy_path)
sampling_frequency = sorting.sampling_frequency

# Kilosort output occasionally contains a few spikes just beyond the recording end point, which can lead
# to errors later. To avoid this, we pad the recording with an extra second of blank time.
duration = sorting._sorting_segments[0]._all_spikes[-1] / sampling_frequency + 1

if (phy_path / "probe.prb").is_file():
probegroup = read_prb(phy_path / "probe.prb")
if len(probegroup.probes) > 1:
warnings.warn("Found more than one probe. Selecting the first probe in ProbeGroup.")
probe = probegroup.probes[0]
elif (phy_path / "channel_positions.npy").is_file():
probe = Probe(si_units="um")
channel_positions = np.load(phy_path / "channel_positions.npy")
probe.set_contacts(channel_positions)
probe.set_device_channel_indices(range(probe.get_contact_count()))
else:
raise AssertionError(f"Cannot read probe layout from folder {phy_path}.")

# to make the initial analyzer, we'll use a fake recording and set it to None later
recording, _ = generate_ground_truth_recording(
probe=probe,
sampling_frequency=sampling_frequency,
durations=[duration],
num_units=1,
seed=1205,
)

sparsity = _make_sparsity_from_templates(sorting, recording, phy_path)

sorting_analyzer = create_sorting_analyzer(sorting, recording, sparse=True, sparsity=sparsity)

# First compute random spikes. These are not used directly, but spikeinterface-gui needs the extension to run
sorting_analyzer.compute("random_spikes")

_make_templates(sorting_analyzer, phy_path, sparsity.mask, sampling_frequency, unwhiten=unwhiten)
_make_locations(sorting_analyzer, phy_path)

sorting_analyzer._recording = None
return sorting_analyzer


def _make_locations(sorting_analyzer, kilosort_output_path):
"""Constructs a `spike_locations` extension from the amplitudes numpy array
in `kilosort_output_path`, and attaches the extension to the `sorting_analyzer`."""

locations_extension = ComputeSpikeLocations(sorting_analyzer)

spike_locations_path = kilosort_output_path / "spike_positions.npy"
if spike_locations_path.is_file():
locs_np = np.load(spike_locations_path)
else:
return

# Check that the spike locations vector is the same size as the spike vector
num_spikes = len(sorting_analyzer.sorting.to_spike_vector())
num_spike_locs = len(locs_np)
if num_spikes != num_spike_locs:
warnings.warn(
"The number of spikes does not match the number of spike locations in `spike_positions.npy`. Skipping spike locations."
)
return

num_dims = len(locs_np[0])
column_names = ["x", "y", "z"][:num_dims]
dtype = [(name, locs_np.dtype) for name in column_names]

structured_array = np.zeros(len(locs_np), dtype=dtype)
for coordinate_index, column_name in enumerate(column_names):
structured_array[column_name] = locs_np[:, coordinate_index]

locations_extension.data = {"spike_locations": structured_array}
locations_extension.params = {}
locations_extension.run_info = {"run_completed": True}

sorting_analyzer.extensions["spike_locations"] = locations_extension


def _make_sparsity_from_templates(sorting, recording, kilosort_output_path):
"""Constructs the `ChannelSparsity` of from kilosort output, by seeing if the
templates output is zero or not on all channels."""

templates = np.load(kilosort_output_path / "templates.npy")

unit_ids = sorting.unit_ids
channel_ids = recording.channel_ids

# The raw templates have dense dimensions (num units) x (num samples) x (num channels)
# but are zero on many channels, which implicitly defines the sparsity
mask = np.sum(np.abs(templates), axis=1) != 0
return ChannelSparsity(mask, unit_ids=unit_ids, channel_ids=channel_ids)


def _make_templates(sorting_analyzer, kilosort_output_path, mask, sampling_frequency, unwhiten=True):
"""Constructs a `templates` extension from the amplitudes numpy array
in `kilosort_output_path`, and attaches the extension to the `sorting_analyzer`."""

template_extension = ComputeTemplates(sorting_analyzer)

whitened_templates = np.load(kilosort_output_path / "templates.npy")
wh_inv = np.load(kilosort_output_path / "whitening_mat_inv.npy")
new_templates = _compute_unwhitened_templates(whitened_templates, wh_inv) if unwhiten else whitened_templates

template_extension.data = {"average": new_templates}

ops_path = kilosort_output_path / "ops.npy"
if ops_path.is_file():
ops = np.load(ops_path, allow_pickle=True)

number_samples_before_template_peak = ops.item(0)["nt0min"]
total_template_samples = ops.item(0)["nt"]

number_samples_after_template_peak = total_template_samples - number_samples_before_template_peak

ms_before = number_samples_before_template_peak / (sampling_frequency // 1000)
ms_after = number_samples_after_template_peak / (sampling_frequency // 1000)

# Used for Kilosort 2, 2.5 and 3, which do not write an ops.npy file
else:
warnings.warn("Can't extract `ms_before` and `ms_after` from Kilosort output. Guessing a sensible value.")

samples_in_templates = np.shape(new_templates)[1]
template_extent_ms = (samples_in_templates + 1) / (sampling_frequency // 1000)
ms_before = template_extent_ms / 3
ms_after = 2 * template_extent_ms / 3

params = {
"operators": ["average"],
"ms_before": ms_before,
"ms_after": ms_after,
"peak_sign": "both",
}

template_extension.params = params
template_extension.run_info = {"run_completed": True}

sorting_analyzer.extensions["templates"] = template_extension


def _compute_unwhitened_templates(whitened_templates, wh_inv):
"""Constructs unwhitened templates from whitened_templates, by
applying an inverse whitening matrix."""

# templates have dimension (num units) x (num samples) x (num channels)
# the whitening inverse has dimension (num channels) x (num channels)
# to undo whitening, we do a matrix multiplication over the channel index
unwhitened_templates = np.einsum("ij,klj->kli", wh_inv, whitened_templates)
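# equivalently, for each unit k and sample l: unwhitened_templates[k, l, :] = wh_inv @ whitened_templates[k, l, :]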

return unwhitened_templates