|
| 1 | +# --- |
| 2 | +# jupyter: |
| 3 | +# jupytext: |
| 4 | +# cell_metadata_filter: -all |
| 5 | +# formats: ipynb,py |
| 6 | +# text_representation: |
| 7 | +# extension: .py |
| 8 | +# format_name: light |
| 9 | +# format_version: '1.5' |
| 10 | +# jupytext_version: 1.16.2 |
| 11 | +# kernelspec: |
| 12 | +# display_name: Python 3 (ipykernel) |
| 13 | +# language: python |
| 14 | +# name: python3 |
| 15 | +# --- |
| 16 | + |
| 17 | +# # Benchmark spike sorting with hybrid recordings |
| 18 | +# |
| 19 | +# This example shows how to use the SpikeInterface hybrid recordings framework to benchmark spike sorting results. |
| 20 | +# |
| 21 | +# Hybrid recordings are built from existing recordings by injecting units with known spiking activity. |
| 22 | +# The template (aka average waveforms) of the injected units can be from previous spike sorted data. |
| 23 | +# In this example, we will be using an open database of templates that we have constructed from the International Brain Laboratory - Brain Wide Map (available on [DANDI](https://dandiarchive.org/dandiset/000409?search=IBL&page=2&sortOption=0&sortDir=-1&showDrafts=true&showEmpty=false&pos=9)). |
| 24 | +# |
| 25 | +# Importantly, recordings from long-shank probes, such as Neuropixels, usually experience drifts. Such drifts have to be taken into account in order to smoothly inject spikes into the recording. |
| 26 | + |
| 27 | +# + |
| 28 | +import spikeinterface as si |
| 29 | +import spikeinterface.extractors as se |
| 30 | +import spikeinterface.preprocessing as spre |
| 31 | +import spikeinterface.comparison as sc |
| 32 | +import spikeinterface.generation as sgen |
| 33 | +import spikeinterface.widgets as sw |
| 34 | + |
| 35 | +from spikeinterface.sortingcomponents.motion_estimation import estimate_motion |
| 36 | + |
| 37 | +import numpy as np |
| 38 | +import matplotlib.pyplot as plt |
| 39 | +from pathlib import Path |
| 40 | +# - |
| 41 | + |
| 42 | +# %matplotlib inline |
| 43 | + |
# Use 16 parallel jobs for every SpikeInterface operation that supports them.
si.set_global_job_kwargs(n_jobs=16)
| 45 | + |
# For this notebook, we will use a drifting recording similar to the one acquired by Nick Steinmetz and available [here](https://doi.org/10.6084/m9.figshare.14024495.v1), where a triangular motion was imposed on the recording by moving the probe up and down with a micro-manipulator.
| 47 | + |
# Working directory for intermediate and output files.
# `parents=True` creates missing intermediate directories as well — a plain
# mkdir(exist_ok=True) raises FileNotFoundError when the parents don't exist.
workdir = Path("/ssd980/working/hybrid/steinmetz_imposed_motion")
workdir.mkdir(exist_ok=True, parents=True)

# Load the raw SpikeGLX recording and apply standard preprocessing:
# high-pass filter followed by common reference.
recording_np1_imposed = se.read_spikeglx("/hdd1/data/spikeglx/nick-steinmetz/dataset1/p1_g0_t0/")
recording_preproc = spre.highpass_filter(recording_np1_imposed)
recording_preproc = spre.common_reference(recording_preproc)
| 54 | + |
| 55 | +# To visualize the drift, we can estimate the motion and plot it: |
| 56 | + |
# to correct for drift, we need a float dtype
recording_preproc = spre.astype(recording_preproc, "float")
# Estimate the motion; `output_motion_info=True` additionally returns the
# motion_info dict (detected peaks, peak locations, motion object) that we
# reuse below. The motion-corrected recording itself is discarded here.
_, motion_info = spre.correct_motion(
    recording_preproc, preset="nonrigid_fast_and_accurate", n_jobs=4, progress_bar=True, output_motion_info=True
)

# Raster map of detected peaks (time vs depth): drift appears as a
# systematic displacement of the peak depths over time.
ax = sw.plot_drift_raster_map(
    peaks=motion_info["peaks"],
    peak_locations=motion_info["peak_locations"],
    recording=recording_preproc,
    cmap="Greys_r",
    scatter_decimate=10,
    depth_lim=(-10, 3000)
)
| 71 | + |
| 72 | +# ## Retrieve templates from database |
| 73 | + |
# +
# Fetch the metadata table of the templates database: one row per template,
# with columns such as `brain_area` and `depth_along_probe` (used below).
templates_info = sgen.fetch_templates_database_info()

print(f"Number of templates in database: {len(templates_info)}")
print(f"Template database columns: {templates_info.columns}")
# -

# Unique brain areas represented in the database.
available_brain_areas = np.unique(templates_info.brain_area)
print(f"Available brain areas: {available_brain_areas}")
| 83 | + |
| 84 | +# Let's perform a query: templates from visual brain regions and at the "top" of the probe |
| 85 | + |
# Keep templates recorded in the listed visual areas whose depth along the
# probe exceeds the minimum ("top" of the probe, per the text above).
target_area = ["VISa5", "VISa6a", "VISp5", "VISp6a", "VISrl6b"]
minimum_depth = 1500
in_target_area = templates_info.brain_area.isin(target_area)
above_minimum_depth = templates_info.depth_along_probe > minimum_depth
templates_selected_info = templates_info[in_target_area & above_minimum_depth]
len(templates_selected_info)
| 90 | + |
| 91 | +# We can now retrieve the selected templates as a `Templates` object: |
| 92 | + |
# Download the selected templates from the database as a `Templates` object.
templates_selected = sgen.query_templates_from_database(templates_selected_info, verbose=True)
print(templates_selected)
| 95 | + |
# While we selected templates from a target area and at certain depths, we can see that the template amplitudes are quite large. This will make spike sorting easy... we can further manipulate the `Templates` by rescaling, relocating, or further selections with the `sgen.scale_template_to_range`, `sgen.relocate_templates`, and `sgen.select_templates` functions.
| 97 | +# |
| 98 | +# In our case, let's rescale the amplitudes between 50 and 150 $\mu$V and relocate them towards the bottom half of the probe, where the activity looks interesting! |
| 99 | + |
# +
# Rescale template amplitudes into the 50-150 uV range so that sorting the
# injected units is not trivially easy (see text above).
min_amplitude = 50
max_amplitude = 150
templates_scaled = sgen.scale_template_to_range(
    templates=templates_selected,
    min_amplitude=min_amplitude,
    max_amplitude=max_amplitude
)

# Relocate the templates with displacements between 1000 and 3000
# (presumably um along the probe — towards the bottom half, per the text).
min_displacement = 1000
max_displacement = 3000
templates_relocated = sgen.relocate_templates(
    templates=templates_scaled,
    min_displacement=min_displacement,
    max_displacement=max_displacement
)
# -
| 117 | + |
| 118 | +# Let's plot the selected templates: |
| 119 | + |
# Compute channel sparsity for plotting, then draw the relocated templates
# on a grid of subplots (4 columns) with extra spacing between panels.
sparsity_plot = si.compute_sparsity(templates_relocated)
fig = plt.figure(figsize=(10, 10))
w = sw.plot_unit_templates(templates_relocated, sparsity=sparsity_plot, ncols=4, figure=fig)
w.figure.subplots_adjust(wspace=0.5, hspace=0.7)
| 124 | + |
| 125 | +# ## Constructing hybrid recordings |
| 126 | +# |
| 127 | +# We can construct now hybrid recordings with the selected templates. |
| 128 | +# |
| 129 | +# We will do this in two ways to show how important it is to account for drifts when injecting hybrid spikes. |
| 130 | +# |
| 131 | +# - For the first recording we will not pass the estimated motion (`recording_hybrid_ignore_drift`). |
| 132 | +# - For the second recording, we will pass and account for the estimated motion (`recording_hybrid_with_drift`). |
| 133 | + |
# Inject the hybrid units WITHOUT passing the estimated motion: the injected
# templates stay static while the underlying recording drifts.
recording_hybrid_ignore_drift, sorting_hybrid = sgen.generate_hybrid_recording(
    recording=recording_preproc, templates=templates_relocated, seed=2308
)
recording_hybrid_ignore_drift
| 138 | + |
| 139 | +# Note that the `generate_hybrid_recording` is warning us that we might want to account for drift! |
| 140 | + |
# by passing the `sorting_hybrid` object, we make sure that injected spikes are the same
# this will take a bit more time because it's interpolating the templates to account for drifts
recording_hybrid_with_drift, sorting_hybrid = sgen.generate_hybrid_recording(
    recording=recording_preproc,
    templates=templates_relocated,
    motion=motion_info["motion"],  # estimated motion from `correct_motion` above
    sorting=sorting_hybrid,  # reuse the spike trains from the no-drift version
    seed=2308,
)
recording_hybrid_with_drift
| 151 | + |
| 152 | +# We can use the `SortingAnalyzer` to estimate spike locations and plot them: |
| 153 | + |
# +
# Build a SortingAnalyzer per hybrid recording and estimate spike locations.
def _analyzer_with_spike_locations(recording):
    """Create a SortingAnalyzer for the hybrid sorting on `recording` and
    compute the extensions needed to localize its spikes."""
    analyzer = si.create_sorting_analyzer(sorting_hybrid, recording)
    analyzer.compute(["random_spikes", "templates"])
    analyzer.compute("spike_locations", method="grid_convolution")
    return analyzer


analyzer_hybrid_ignore_drift = _analyzer_with_spike_locations(recording_hybrid_ignore_drift)
analyzer_hybrid_with_drift = _analyzer_with_spike_locations(recording_hybrid_with_drift)
# -
| 164 | + |
| 165 | +# Let's plot the added hybrid spikes using the drift maps: |
| 166 | + |
# Side-by-side drift raster maps: grey background of all detected peaks,
# with the injected hybrid spikes overlaid in color (left: drift ignored,
# right: drift accounted for).
fig, axs = plt.subplots(ncols=2, figsize=(10, 7), sharex=True, sharey=True)
overlays = (
    (axs[0], analyzer_hybrid_ignore_drift, "r"),
    (axs[1], analyzer_hybrid_with_drift, "b"),
)
for panel, analyzer, spike_color in overlays:
    # background: all detected peaks, as in the drift map shown earlier
    _ = sw.plot_drift_raster_map(
        peaks=motion_info["peaks"],
        peak_locations=motion_info["peak_locations"],
        recording=recording_preproc,
        cmap="Greys_r",
        scatter_decimate=10,
        ax=panel,
    )
    # foreground: the injected hybrid spikes in a flat color
    _ = sw.plot_drift_raster_map(
        sorting_analyzer=analyzer,
        color_amplitude=False,
        color=spike_color,
        scatter_decimate=10,
        ax=panel,
    )
axs[0].set_title("Hybrid spikes\nIgnoring drift")
axs[1].set_title("Hybrid spikes\nAccounting for drift")
# zoom into a representative time/depth window (axes are shared)
axs[0].set_xlim(1000, 1500)
axs[0].set_ylim(500, 2500)
| 202 | + |
| 203 | +# We can see that clearly following drift is essential in order to properly blend the hybrid spikes into the recording! |
| 204 | + |
| 205 | +# ## Ground-truth study |
| 206 | +# |
| 207 | +# In this section we will use the hybrid recording to benchmark a few spike sorters: |
| 208 | +# |
| 209 | +# - `Kilosort2.5` |
| 210 | +# - `Kilosort3` |
| 211 | +# - `Kilosort4` |
| 212 | +# - `Spyking-CIRCUS 2` |
| 213 | + |
# to speed up computations, let's first dump the recording to binary
# (`overwrite=True` lets the notebook be re-run without manual cleanup)
recording_hybrid_bin = recording_hybrid_with_drift.save(
    folder=workdir / "hybrid_bin",
    overwrite=True
)
| 219 | + |
# +
# Study datasets: name -> (recording, ground-truth sorting) pair.
datasets = {
    "hybrid": (recording_hybrid_bin, sorting_hybrid),
}
| 224 | + |
# Benchmark cases: each (case-name, dataset-name) key maps to the label used
# in plots, the dataset to sort, and the parameters passed to `run_sorter`.
_sorter_specs = [
    ("kilosort2.5", "KS2.5", {"sorter_name": "kilosort2_5"}),
    ("kilosort3", "KS3", {"sorter_name": "kilosort3"}),
    ("kilosort4", "KS4", {"sorter_name": "kilosort4", "nblocks": 5}),
    ("sc2", "spykingcircus2", {"sorter_name": "spykingcircus2"}),
]
cases = {
    (case_name, "hybrid"): {
        "label": label,
        "dataset": "hybrid",
        "run_sorter_params": sorter_params,
    }
    for case_name, label, sorter_params in _sorter_specs
}
| 253 | + |
# +
study_folder = workdir / "gt_study"

# Create the study on the first run — this registers `datasets` and `cases`
# and writes them to disk; on later runs, simply load the study back.
# NOTE(review): the original called `sc.GroundTruthStudy(study_folder)`
# unconditionally, which only loads an existing study, so the `datasets`
# and `cases` defined above were never used.
if study_folder.exists():
    gtstudy = sc.GroundTruthStudy(study_folder)
else:
    gtstudy = sc.GroundTruthStudy.create(study_folder, datasets=datasets, cases=cases)
# -

# run the spike sorting jobs (keep=True skips cases that already have results)
gtstudy.run_sorters(verbose=False, keep=True)

# run the comparisons against the ground-truth sorting
# (exhaustive_gt=False: the hybrid units are not the only activity present)
gtstudy.run_comparisons(exhaustive_gt=False)
| 266 | + |
| 267 | +# ## Plot performances |
| 268 | +# |
# Given that we know exactly where we injected the hybrid spikes, we can now compute and plot performance metrics: accuracy, precision, and recall.
| 270 | +# |
| 271 | +# In the following plot, the x axis is the unit index, while the y axis is the performance metric. The units are sorted by performance. |
| 272 | + |
# Plot per-unit performance (accuracy, precision, recall — see text above)
# for each case; put the legend in the lower-right corner of the first panel.
w_perf = sw.plot_study_performances(gtstudy, figsize=(12, 7))
w_perf.axes[0, 0].legend(loc=4)
| 275 | + |
| 276 | +# From the performance plots, we can see that there is no clear "winner", but `Kilosort3` definitely performs worse than the other options. |
| 277 | +# |
# Although none of the sorters finds all units perfectly, `Kilosort2.5`, `Kilosort4`, and `SpyKING CIRCUS 2` all find around 10-12 hybrid units with accuracy greater than 80%.
# `Kilosort4` has a better overall curve, being able to find almost all units with an accuracy above 50%. `Kilosort2.5` performs well when looking at precision (rarely reporting spikes that are not really there), at the cost of lower recall (missing some of the true spikes of a hybrid unit).
| 280 | +# |
| 281 | +# |
| 282 | +# In this example, we showed how to: |
| 283 | +# |
| 284 | +# - Access and fetch templates from the SpikeInterface template database |
| 285 | +# - Manipulate templates (scaling/relocating) |
| 286 | +# - Construct hybrid recordings accounting for drifts |
| 287 | +# - Use the `GroundTruthStudy` to benchmark different sorters |
| 288 | +# |
| 289 | +# The hybrid framework can be extended to target multiple recordings from different brain regions and species and creating recordings of increasing complexity to challenge the existing sorters! |
| 290 | +# |
| 291 | +# In addition, hybrid studies can also be used to fine-tune spike sorting parameters on specific datasets. |
| 292 | +# |
| 293 | +# **Are you ready to try it on your data?** |
0 commit comments