
Commit 3dd5f10

Merge branch 'main' into fix_paths_in_json
2 parents: 815f605 + 4539550

38 files changed: +4400 −750 lines

doc/how_to/benchmark_with_hybrid_recordings.rst

Lines changed: 2552 additions & 0 deletions
Large diffs are not rendered by default.
Four image files added (179 KB, 384 KB, 79.7 KB, 208 KB); binary content not rendered.

doc/how_to/index.rst

Lines changed: 1 addition & 0 deletions
@@ -13,3 +13,4 @@ Guides on how to solve specific, short problems in SpikeInterface. Learn how to.
 combine_recordings
 process_by_channel_group
 load_your_data_into_sorting
+benchmark_with_hybrid_recordings

examples/how_to/README.md

Lines changed: 8 additions & 4 deletions
@@ -14,17 +14,21 @@ with `nbconvert`. Here are the steps (in this example for the `get_started`):
 ```
 >>> jupytext --to notebook get_started.py
+>>> jupytext --set-formats ipynb,py get_started.ipynb
 ```

 2. Run the notebook

-3. Convert the notebook to .rst
+3. Sync the executed notebook back to the .py file:
+
+```
+>>> jupytext --sync get_started.ipynb
+```
+
+4. Convert the notebook to .rst

 ```
 >>> jupyter nbconvert get_started.ipynb --to rst
->>> jupyter nbconvert analyse_neuropixels.ipynb --to rst
 ```

-4. Move the .rst and associated folder (e.g. `get_started.rst` and `get_started_files` folder) to the `doc/how_to`.
+5. Move the .rst and associated folder (e.g. `get_started.rst` and `get_started_files` folder) to the `doc/how_to`.
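
For reference, the move in step 5 might look like the following when run from the `examples/how_to` folder (the relative path to `doc/how_to` is an assumption about the repository layout):

```
>>> mv get_started.rst get_started_files ../../doc/how_to/
```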
Lines changed: 1 addition & 1 deletion
@@ -14,7 +14,7 @@
 # name: python3
 # ---

-# # Analyse Neuropixels datasets
+# # Analyze Neuropixels datasets
 #
 # This example shows how to perform Neuropixels-specific analysis, including custom pre- and post-processing.

Lines changed: 293 additions & 0 deletions
@@ -0,0 +1,293 @@
# ---
# jupyter:
#   jupytext:
#     cell_metadata_filter: -all
#     formats: ipynb,py
#     text_representation:
#       extension: .py
#       format_name: light
#       format_version: '1.5'
#     jupytext_version: 1.16.2
#   kernelspec:
#     display_name: Python 3 (ipykernel)
#     language: python
#     name: python3
# ---

# # Benchmark spike sorting with hybrid recordings
#
# This example shows how to use the SpikeInterface hybrid recordings framework to benchmark spike sorting results.
#
# Hybrid recordings are built from existing recordings by injecting units with known spiking activity.
# The templates (i.e., the average waveforms) of the injected units can come from previously spike-sorted data.
# In this example, we will use an open database of templates that we constructed from the International Brain Laboratory - Brain Wide Map (available on [DANDI](https://dandiarchive.org/dandiset/000409?search=IBL&page=2&sortOption=0&sortDir=-1&showDrafts=true&showEmpty=false&pos=9)).
#
# Importantly, recordings from long-shank probes, such as Neuropixels, usually experience drift. Such drift has to be taken into account in order to smoothly inject spikes into the recording.

# +
import spikeinterface as si
import spikeinterface.extractors as se
import spikeinterface.preprocessing as spre
import spikeinterface.comparison as sc
import spikeinterface.generation as sgen
import spikeinterface.widgets as sw

from spikeinterface.sortingcomponents.motion_estimation import estimate_motion

import numpy as np
import matplotlib.pyplot as plt
from pathlib import Path
# -

# %matplotlib inline

si.set_global_job_kwargs(n_jobs=16)
# For this notebook, we will use a drifting recording similar to the one acquired by Nick Steinmetz and available [here](https://doi.org/10.6084/m9.figshare.14024495.v1), where a triangular motion was imposed on the recording by moving the probe up and down with a micro-manipulator.

workdir = Path("/ssd980/working/hybrid/steinmetz_imposed_motion")
workdir.mkdir(exist_ok=True)

recording_np1_imposed = se.read_spikeglx("/hdd1/data/spikeglx/nick-steinmetz/dataset1/p1_g0_t0/")
recording_preproc = spre.highpass_filter(recording_np1_imposed)
recording_preproc = spre.common_reference(recording_preproc)

# To visualize the drift, we can estimate the motion and plot it:

# to correct for drift, we need a float dtype
recording_preproc = spre.astype(recording_preproc, "float")
_, motion_info = spre.correct_motion(
    recording_preproc, preset="nonrigid_fast_and_accurate", n_jobs=4, progress_bar=True, output_motion_info=True
)

ax = sw.plot_drift_raster_map(
    peaks=motion_info["peaks"],
    peak_locations=motion_info["peak_locations"],
    recording=recording_preproc,
    cmap="Greys_r",
    scatter_decimate=10,
    depth_lim=(-10, 3000)
)
# ## Retrieve templates from database

# +
templates_info = sgen.fetch_templates_database_info()

print(f"Number of templates in database: {len(templates_info)}")
print(f"Template database columns: {templates_info.columns}")
# -

available_brain_areas = np.unique(templates_info.brain_area)
print(f"Available brain areas: {available_brain_areas}")

# Let's perform a query: templates from visual brain regions and at the "top" of the probe

target_area = ["VISa5", "VISa6a", "VISp5", "VISp6a", "VISrl6b"]
minimum_depth = 1500
templates_selected_info = templates_info.query(f"brain_area in {target_area} and depth_along_probe > {minimum_depth}")
len(templates_selected_info)

# We can now retrieve the selected templates as a `Templates` object:

templates_selected = sgen.query_templates_from_database(templates_selected_info, verbose=True)
print(templates_selected)

# Although we selected templates from a target area and at certain depths, the template amplitudes are quite large, which would make spike sorting too easy. We can further manipulate the `Templates` by rescaling, relocating, or applying further selections with the `sgen.scale_template_to_range`, `sgen.relocate_templates`, and `sgen.select_templates` functions.
#
# In our case, let's rescale the amplitudes between 50 and 150 $\mu$V and relocate them towards the bottom half of the probe, where the activity looks interesting!
# +
min_amplitude = 50
max_amplitude = 150
templates_scaled = sgen.scale_template_to_range(
    templates=templates_selected,
    min_amplitude=min_amplitude,
    max_amplitude=max_amplitude
)

min_displacement = 1000
max_displacement = 3000
templates_relocated = sgen.relocate_templates(
    templates=templates_scaled,
    min_displacement=min_displacement,
    max_displacement=max_displacement
)
# -

# Let's plot the selected templates:

sparsity_plot = si.compute_sparsity(templates_relocated)
fig = plt.figure(figsize=(10, 10))
w = sw.plot_unit_templates(templates_relocated, sparsity=sparsity_plot, ncols=4, figure=fig)
w.figure.subplots_adjust(wspace=0.5, hspace=0.7)
# ## Constructing hybrid recordings
#
# We can now construct hybrid recordings with the selected templates.
#
# We will do this in two ways to show how important it is to account for drift when injecting hybrid spikes.
#
# - For the first recording, we will not pass the estimated motion (`recording_hybrid_ignore_drift`).
# - For the second recording, we will pass and account for the estimated motion (`recording_hybrid_with_drift`).

recording_hybrid_ignore_drift, sorting_hybrid = sgen.generate_hybrid_recording(
    recording=recording_preproc, templates=templates_relocated, seed=2308
)
recording_hybrid_ignore_drift

# Note that `generate_hybrid_recording` warns us that we might want to account for drift!

# by passing the `sorting_hybrid` object, we make sure that the injected spikes are the same
# this will take a bit more time because it's interpolating the templates to account for drift
recording_hybrid_with_drift, sorting_hybrid = sgen.generate_hybrid_recording(
    recording=recording_preproc,
    templates=templates_relocated,
    motion=motion_info["motion"],
    sorting=sorting_hybrid,
    seed=2308,
)
recording_hybrid_with_drift
# We can use the `SortingAnalyzer` to estimate spike locations and plot them:

# +
# construct analyzers and compute spike locations
analyzer_hybrid_ignore_drift = si.create_sorting_analyzer(sorting_hybrid, recording_hybrid_ignore_drift)
analyzer_hybrid_ignore_drift.compute(["random_spikes", "templates"])
analyzer_hybrid_ignore_drift.compute("spike_locations", method="grid_convolution")

analyzer_hybrid_with_drift = si.create_sorting_analyzer(sorting_hybrid, recording_hybrid_with_drift)
analyzer_hybrid_with_drift.compute(["random_spikes", "templates"])
analyzer_hybrid_with_drift.compute("spike_locations", method="grid_convolution")
# -

# Let's plot the added hybrid spikes using the drift maps:
fig, axs = plt.subplots(ncols=2, figsize=(10, 7), sharex=True, sharey=True)
_ = sw.plot_drift_raster_map(
    peaks=motion_info["peaks"],
    peak_locations=motion_info["peak_locations"],
    recording=recording_preproc,
    cmap="Greys_r",
    scatter_decimate=10,
    ax=axs[0],
)
_ = sw.plot_drift_raster_map(
    sorting_analyzer=analyzer_hybrid_ignore_drift,
    color_amplitude=False,
    color="r",
    scatter_decimate=10,
    ax=axs[0]
)
_ = sw.plot_drift_raster_map(
    peaks=motion_info["peaks"],
    peak_locations=motion_info["peak_locations"],
    recording=recording_preproc,
    cmap="Greys_r",
    scatter_decimate=10,
    ax=axs[1],
)
_ = sw.plot_drift_raster_map(
    sorting_analyzer=analyzer_hybrid_with_drift,
    color_amplitude=False,
    color="b",
    scatter_decimate=10,
    ax=axs[1]
)
axs[0].set_title("Hybrid spikes\nIgnoring drift")
axs[1].set_title("Hybrid spikes\nAccounting for drift")
axs[0].set_xlim(1000, 1500)
axs[0].set_ylim(500, 2500)

# We can clearly see that following the drift is essential in order to properly blend the hybrid spikes into the recording!
# ## Ground-truth study
#
# In this section we will use the hybrid recording to benchmark a few spike sorters:
#
# - `Kilosort2.5`
# - `Kilosort3`
# - `Kilosort4`
# - `Spyking-CIRCUS 2`

# to speed up computations, let's first dump the recording to binary
recording_hybrid_bin = recording_hybrid_with_drift.save(
    folder=workdir / "hybrid_bin",
    overwrite=True
)

# +
datasets = {
    "hybrid": (recording_hybrid_bin, sorting_hybrid),
}

cases = {
    ("kilosort2.5", "hybrid"): {
        "label": "KS2.5",
        "dataset": "hybrid",
        "run_sorter_params": {
            "sorter_name": "kilosort2_5",
        },
    },
    ("kilosort3", "hybrid"): {
        "label": "KS3",
        "dataset": "hybrid",
        "run_sorter_params": {
            "sorter_name": "kilosort3",
        },
    },
    ("kilosort4", "hybrid"): {
        "label": "KS4",
        "dataset": "hybrid",
        "run_sorter_params": {"sorter_name": "kilosort4", "nblocks": 5},
    },
    ("sc2", "hybrid"): {
        "label": "spykingcircus2",
        "dataset": "hybrid",
        "run_sorter_params": {
            "sorter_name": "spykingcircus2",
        },
    },
}
# +
study_folder = workdir / "gt_study"

gtstudy = sc.GroundTruthStudy(study_folder)

# -
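
# If the study has not been created on disk yet, the `datasets` and `cases` dictionaries
# defined above would first be registered when creating it. A minimal sketch, assuming the
# `GroundTruthStudy.create` classmethod available in recent SpikeInterface versions:

# +
# create the study on the first run; afterwards it can simply be re-loaded
# with `sc.GroundTruthStudy(study_folder)` as above
if not study_folder.exists():
    gtstudy = sc.GroundTruthStudy.create(study_folder, datasets=datasets, cases=cases)
# -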
# run the spike sorting jobs
gtstudy.run_sorters(verbose=False, keep=True)

# run the comparisons
gtstudy.run_comparisons(exhaustive_gt=False)
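
# Beyond the summary plots below, the per-unit performance values can be inspected directly.
# A minimal sketch, assuming the `get_performance_by_unit` method of the `GroundTruthStudy`
# API in recent SpikeInterface versions, which returns a pandas DataFrame:

perf_by_unit = gtstudy.get_performance_by_unit()
perf_by_unit.head()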
# ## Plot performances
#
# Given that we know exactly where we injected the hybrid spikes, we can now compute and plot performance metrics: accuracy, precision, and recall.
#
# In the following plots, the x-axis is the unit index and the y-axis is the performance metric, with units sorted by decreasing performance.
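#
# As a reminder, these metrics are computed per hybrid unit from the numbers of true positive
# ($tp$), false negative ($fn$), and false positive ($fp$) spikes, following the usual
# definitions used for ground-truth comparisons in SpikeInterface:
#
# $$\mathrm{accuracy} = \frac{tp}{tp + fn + fp}, \qquad \mathrm{precision} = \frac{tp}{tp + fp}, \qquad \mathrm{recall} = \frac{tp}{tp + fn}$$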
w_perf = sw.plot_study_performances(gtstudy, figsize=(12, 7))
w_perf.axes[0, 0].legend(loc=4)

# From the performance plots, we can see that there is no clear "winner", but `Kilosort3` definitely performs worse than the other options.
#
# Although none of the sorters finds all units perfectly, `Kilosort2.5`, `Kilosort4`, and `SpyKING CIRCUS 2` all find around 10-12 hybrid units with an accuracy greater than 80%.
# `Kilosort4` has a better overall curve, finding almost all units with an accuracy above 50%. `Kilosort2.5` performs well in terms of precision (it rarely assigns spikes that do not belong to a hybrid unit), at the cost of a lower recall (it misses some of the hybrid units' spikes).
#
#
# In this example, we showed how to:
#
# - Access and fetch templates from the SpikeInterface template database
# - Manipulate templates (scaling/relocating)
# - Construct hybrid recordings accounting for drift
# - Use the `GroundTruthStudy` to benchmark different sorters
#
# The hybrid framework can be extended to target multiple recordings from different brain regions and species, and to create recordings of increasing complexity to challenge the existing sorters!
#
# In addition, hybrid studies can also be used to fine-tune spike sorting parameters on specific datasets.
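#
# For instance, a minimal sketch of such a parameter search (here sweeping the `nblocks`
# parameter of `Kilosort4`; the values are chosen purely for illustration) would simply add
# one case per parameter value, re-using the same hybrid dataset:

# +
cases_nblocks_sweep = {
    (f"ks4_nblocks{nblocks}", "hybrid"): {
        "label": f"KS4 (nblocks={nblocks})",
        "dataset": "hybrid",
        "run_sorter_params": {"sorter_name": "kilosort4", "nblocks": nblocks},
    }
    for nblocks in (1, 5, 10)
}
# -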
#
# **Are you ready to try it on your data?**

pyproject.toml

Lines changed: 3 additions & 0 deletions
@@ -139,6 +139,9 @@ test = [
     # preprocessing
     "ibllib>=2.36.0", # for IBL

+    # streaming templates
+    "s3fs",
+
     # tridesclous
     "numba",
     "hdbscan>=0.8.33", # Previous version had a broken wheel
