diff --git a/src/spikeinterface/metrics/conftest.py b/src/spikeinterface/metrics/conftest.py index 8d32c103fa..c2a6c6fe82 100644 --- a/src/spikeinterface/metrics/conftest.py +++ b/src/spikeinterface/metrics/conftest.py @@ -1,8 +1,85 @@ import pytest -from spikeinterface.postprocessing.tests.conftest import _small_sorting_analyzer +from spikeinterface.core import ( + generate_ground_truth_recording, + create_sorting_analyzer, +) + +job_kwargs = dict(n_jobs=2, progress_bar=True, chunk_duration="1s") + + +def make_small_analyzer(): + recording, sorting = generate_ground_truth_recording( + durations=[2.0], + num_units=10, + seed=1205, + ) + + channel_ids_as_integers = [id for id in range(recording.get_num_channels())] + unit_ids_as_integers = [id for id in range(sorting.get_num_units())] + recording = recording.rename_channels(new_channel_ids=channel_ids_as_integers) + sorting = sorting.rename_units(new_unit_ids=unit_ids_as_integers) + + sorting = sorting.select_units([2, 7, 0], ["#3", "#9", "#4"]) + + sorting_analyzer = create_sorting_analyzer(recording=recording, sorting=sorting, format="memory") + + extensions_to_compute = { + "random_spikes": {"seed": 1205}, + "noise_levels": {"seed": 1205}, + "waveforms": {}, + "templates": {"operators": ["average", "median"]}, + "spike_amplitudes": {}, + "spike_locations": {}, + "principal_components": {}, + } + + sorting_analyzer.compute(extensions_to_compute) + + return sorting_analyzer @pytest.fixture(scope="module") def small_sorting_analyzer(): - return _small_sorting_analyzer() + return make_small_analyzer() + + +@pytest.fixture(scope="module") +def sorting_analyzer_simple(): + # we need high firing rate for amplitude_cutoff + recording, sorting = generate_ground_truth_recording( + durations=[ + 120.0, + ], + sampling_frequency=30_000.0, + num_channels=6, + num_units=10, + generate_sorting_kwargs=dict(firing_rates=10.0, refractory_period_ms=4.0), + generate_unit_locations_kwargs=dict( + margin_um=5.0, + minimum_z=5.0, + maximum_z=20.0, + ), + generate_templates_kwargs=dict( + unit_params=dict( + alpha=(200.0, 500.0), + ) + ), + noise_kwargs=dict(noise_levels=5.0, strategy="tile_pregenerated"), + seed=1205, + ) + + channel_ids_as_integers = [id for id in range(recording.get_num_channels())] + unit_ids_as_integers = [id for id in range(sorting.get_num_units())] + recording = recording.rename_channels(new_channel_ids=channel_ids_as_integers) + sorting = sorting.rename_units(new_unit_ids=unit_ids_as_integers) + + sorting_analyzer = create_sorting_analyzer(sorting, recording, format="memory", sparse=True) + + sorting_analyzer.compute("random_spikes", max_spikes_per_unit=300, seed=1205) + sorting_analyzer.compute("noise_levels") + sorting_analyzer.compute("waveforms", **job_kwargs) + sorting_analyzer.compute("templates") + sorting_analyzer.compute(["spike_amplitudes", "spike_locations"], **job_kwargs) + + return sorting_analyzer diff --git a/src/spikeinterface/metrics/quality/misc_metrics.py b/src/spikeinterface/metrics/quality/misc_metrics.py index c4d8941ccc..17a16742ba 100644 --- a/src/spikeinterface/metrics/quality/misc_metrics.py +++ b/src/spikeinterface/metrics/quality/misc_metrics.py @@ -851,6 +851,52 @@ class AmplitudeMedian(BaseMetric): depend_on = ["spike_amplitudes"] +def compute_waveform_ptp_medians(sorting_analyzer, unit_ids=None): + """ + Compute median of the peak-to-peak (PTP) values of the waveforms. + + Parameters + ---------- + sorting_analyzer : SortingAnalyzer + A SortingAnalyzer object. + unit_ids : list or None + List of unit ids to compute the waveform PTP medians. If None, all units are used. + + Returns + ------- + all_waveform_ptp_medians : dict + Estimated waveform PTP median for each unit ID. + + References + ---------- + Inspired by bombcell folks + """ + if unit_ids is None: + unit_ids = sorting_analyzer.unit_ids + + check_has_required_extensions("waveform_ptp_median", sorting_analyzer) + + wfs_ext = sorting_analyzer.get_extension("waveforms") + extremum_channel_indices = get_template_extremum_channel(sorting_analyzer, outputs="index", mode="peak_to_peak") + all_waveform_ptp_medians = {} + + for unit_id in unit_ids: + waveforms = wfs_ext.get_waveforms_one_unit(unit_id, force_dense=True) + waveform_max_channel = waveforms[:, :, extremum_channel_indices[unit_id]] + ptps = np.ptp(waveform_max_channel, axis=1) + median_ptp = np.median(ptps) + all_waveform_ptp_medians[unit_id] = median_ptp + + return all_waveform_ptp_medians + + +class WaveformPTPMedian(BaseMetric): + metric_name = "waveform_ptp_median" + metric_function = compute_waveform_ptp_medians + metric_columns = {"waveform_ptp_median": float} + depend_on = ["waveforms"] + + def compute_noise_cutoffs(sorting_analyzer, unit_ids=None, high_quantile=0.25, low_quantile=0.1, n_bins=100): """ A metric to determine if a unit's amplitude distribution is cut off as it approaches zero, without assuming a Gaussian distribution. @@ -1258,6 +1304,7 @@ class SDRatio(BaseMetric): AmplitudeCutoff, NoiseCutoff, AmplitudeMedian, + WaveformPTPMedian, Drift, SDRatio, ] diff --git a/src/spikeinterface/metrics/quality/tests/conftest.py b/src/spikeinterface/metrics/quality/tests/conftest.py deleted file mode 100644 index c2a6c6fe82..0000000000 --- a/src/spikeinterface/metrics/quality/tests/conftest.py +++ /dev/null @@ -1,85 +0,0 @@ -import pytest - -from spikeinterface.core import ( - generate_ground_truth_recording, - create_sorting_analyzer, -) - -job_kwargs = dict(n_jobs=2, progress_bar=True, chunk_duration="1s") - - -def make_small_analyzer(): - recording, sorting = generate_ground_truth_recording( - durations=[2.0], - num_units=10, - seed=1205, - ) - - channel_ids_as_integers = [id for id in range(recording.get_num_channels())] - unit_ids_as_integers = [id for id in range(sorting.get_num_units())] - recording = recording.rename_channels(new_channel_ids=channel_ids_as_integers) - sorting = sorting.rename_units(new_unit_ids=unit_ids_as_integers) - - sorting = sorting.select_units([2, 7, 0], ["#3", "#9", "#4"]) - - sorting_analyzer = create_sorting_analyzer(recording=recording, sorting=sorting, format="memory") - - extensions_to_compute = { - "random_spikes": {"seed": 1205}, - "noise_levels": {"seed": 1205}, - "waveforms": {}, - "templates": {"operators": ["average", "median"]}, - "spike_amplitudes": {}, - "spike_locations": {}, - "principal_components": {}, - } - - sorting_analyzer.compute(extensions_to_compute) - - return sorting_analyzer - - -@pytest.fixture(scope="module") -def small_sorting_analyzer(): - return make_small_analyzer() - - -@pytest.fixture(scope="module") -def sorting_analyzer_simple(): - # we need high firing rate for amplitude_cutoff - recording, sorting = generate_ground_truth_recording( - durations=[ - 120.0, - ], - sampling_frequency=30_000.0, - num_channels=6, - num_units=10, - generate_sorting_kwargs=dict(firing_rates=10.0, refractory_period_ms=4.0), - generate_unit_locations_kwargs=dict( - margin_um=5.0, - minimum_z=5.0, - maximum_z=20.0, - ), - generate_templates_kwargs=dict( - unit_params=dict( - alpha=(200.0, 500.0), - ) - ), - noise_kwargs=dict(noise_levels=5.0, strategy="tile_pregenerated"), - seed=1205, - ) - - channel_ids_as_integers = [id for id in range(recording.get_num_channels())] - unit_ids_as_integers = [id for id in range(sorting.get_num_units())] - recording = recording.rename_channels(new_channel_ids=channel_ids_as_integers) - sorting = sorting.rename_units(new_unit_ids=unit_ids_as_integers) - - sorting_analyzer = create_sorting_analyzer(sorting, recording, format="memory", sparse=True) - - sorting_analyzer.compute("random_spikes", max_spikes_per_unit=300, seed=1205) - sorting_analyzer.compute("noise_levels") - sorting_analyzer.compute("waveforms", **job_kwargs) - sorting_analyzer.compute("templates") - sorting_analyzer.compute(["spike_amplitudes", "spike_locations"], **job_kwargs) - - return sorting_analyzer diff --git a/src/spikeinterface/metrics/quality/tests/test_metrics_functions.py b/src/spikeinterface/metrics/quality/tests/test_metrics_functions.py index c0dd6c6033..38dd6fd762 100644 --- a/src/spikeinterface/metrics/quality/tests/test_metrics_functions.py +++ b/src/spikeinterface/metrics/quality/tests/test_metrics_functions.py @@ -1,5 +1,4 @@ import pytest -from pathlib import Path import numpy as np from copy import deepcopy import csv @@ -14,13 +13,9 @@ from spikeinterface.metrics.quality.utils import create_ground_truth_pc_distributions -# from spikeinterface.metrics.quality_metric_list import ( -# _misc_metric_name_to_func, -# ) from spikeinterface.metrics.quality import ( get_quality_metric_list, - get_quality_pca_metric_list, compute_quality_metrics, ) from spikeinterface.metrics.quality.misc_metrics import ( @@ -28,12 +23,11 @@ compute_amplitude_cutoffs, compute_presence_ratios, compute_isi_violations, - # compute_firing_rates, - # compute_num_spikes, compute_snrs, compute_refrac_period_violations, compute_sliding_rp_violations, compute_drift_metrics, + compute_waveform_ptp_medians, compute_amplitude_medians, compute_synchrony_metrics, compute_firing_ranges, @@ -44,7 +38,6 @@ ) from spikeinterface.metrics.quality.pca_metrics import ( - pca_metrics_list, mahalanobis_metrics, d_prime_metric, nearest_neighbors_metrics, @@ -486,18 +479,6 @@ def test_simplified_silhouette_score_metrics(): assert sim_sil_score1 < sim_sil_score2 -# def test_calculate_firing_rate_num_spikes(sorting_analyzer_simple): -# sorting_analyzer = sorting_analyzer_simple -# firing_rates = compute_firing_rates(sorting_analyzer) -# num_spikes = compute_num_spikes(sorting_analyzer) - -# testing method accuracy with magic number is not a good pratcice, I remove this. -# firing_rates_gt = {0: 10.01, 1: 5.03, 2: 5.09} -# num_spikes_gt = {0: 1001, 1: 503, 2: 509} -# assert np.allclose(list(firing_rates_gt.values()), list(firing_rates.values()), rtol=0.05) -# np.testing.assert_array_equal(list(num_spikes_gt.values()), list(num_spikes.values())) - - def test_calculate_firing_range(sorting_analyzer_simple): sorting_analyzer = sorting_analyzer_simple firing_ranges = compute_firing_ranges(sorting_analyzer) @@ -532,6 +513,12 @@ def test_calculate_amplitude_median(sorting_analyzer_simple): # assert np.allclose(list(amp_medians_gt.values()), list(amp_medians.values()), rtol=0.05) +def test_calculate_waveform_ptp_median(sorting_analyzer_simple): + sorting_analyzer = sorting_analyzer_simple + # spike_amps = sorting_analyzer.get_extension("spike_amplitudes").get_data() + wfs_ptp_medians = compute_waveform_ptp_medians(sorting_analyzer) + + def test_calculate_amplitude_cv_metrics(sorting_analyzer_simple): sorting_analyzer = sorting_analyzer_simple amp_cv_median, amp_cv_range = compute_amplitude_cv_metrics(sorting_analyzer, average_num_spikes_per_bin=20) diff --git a/src/spikeinterface/metrics/spiketrain/tests/test_spiketrain_metrics.py b/src/spikeinterface/metrics/spiketrain/tests/test_spiketrain_metrics.py new file mode 100644 index 0000000000..5e5e18490b --- /dev/null +++ b/src/spikeinterface/metrics/spiketrain/tests/test_spiketrain_metrics.py @@ -0,0 +1,17 @@ +import numpy as np +from spikeinterface.metrics.spiketrain import ( + compute_firing_rates, + compute_num_spikes, +) + + +def test_calculate_firing_ratess(sorting_analyzer_simple): + sorting_analyzer = sorting_analyzer_simple + firing_rates = compute_firing_rates(sorting_analyzer) + assert np.all(np.array(list(firing_rates.values())) > 0) + + +def test_calculate_num_spikes(sorting_analyzer_simple): + sorting_analyzer = sorting_analyzer_simple + num_spikes = compute_num_spikes(sorting_analyzer) + assert np.all(np.array(list(num_spikes.values())) > 0)