Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
81 changes: 79 additions & 2 deletions src/spikeinterface/metrics/conftest.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,85 @@
import pytest

from spikeinterface.postprocessing.tests.conftest import _small_sorting_analyzer
from spikeinterface.core import (
generate_ground_truth_recording,
create_sorting_analyzer,
)

# Shared parallel-processing settings used by the fixtures in this module.
job_kwargs = dict(n_jobs=2, progress_bar=True, chunk_duration="1s")


def make_small_analyzer():
    """Build a small in-memory SortingAnalyzer (3 units, 2 s) with common extensions computed."""
    recording, sorting = generate_ground_truth_recording(
        durations=[2.0],
        num_units=10,
        seed=1205,
    )

    # Use plain integer ids for channels and units.
    recording = recording.rename_channels(new_channel_ids=list(range(recording.get_num_channels())))
    sorting = sorting.rename_units(new_unit_ids=list(range(sorting.get_num_units())))

    # Keep only three of the generated units, relabelled with string ids.
    sorting = sorting.select_units([2, 7, 0], ["#3", "#9", "#4"])

    analyzer = create_sorting_analyzer(recording=recording, sorting=sorting, format="memory")

    analyzer.compute(
        {
            "random_spikes": {"seed": 1205},
            "noise_levels": {"seed": 1205},
            "waveforms": {},
            "templates": {"operators": ["average", "median"]},
            "spike_amplitudes": {},
            "spike_locations": {},
            "principal_components": {},
        }
    )

    return analyzer


@pytest.fixture(scope="module")
def small_sorting_analyzer():
    """Module-scoped fixture returning a small pre-computed SortingAnalyzer."""
    # The previous revision left a stale duplicate return to the old
    # postprocessing `_small_sorting_analyzer` helper above this line,
    # making the intended call unreachable dead code; only the local
    # builder should be used.
    return make_small_analyzer()


@pytest.fixture(scope="module")
def sorting_analyzer_simple():
    """Module-scoped analyzer with high firing rates (needed for amplitude_cutoff)."""
    recording, sorting = generate_ground_truth_recording(
        durations=[120.0],
        sampling_frequency=30_000.0,
        num_channels=6,
        num_units=10,
        generate_sorting_kwargs=dict(firing_rates=10.0, refractory_period_ms=4.0),
        generate_unit_locations_kwargs=dict(margin_um=5.0, minimum_z=5.0, maximum_z=20.0),
        generate_templates_kwargs=dict(unit_params=dict(alpha=(200.0, 500.0))),
        noise_kwargs=dict(noise_levels=5.0, strategy="tile_pregenerated"),
        seed=1205,
    )

    # Use plain integer ids for channels and units.
    recording = recording.rename_channels(new_channel_ids=list(range(recording.get_num_channels())))
    sorting = sorting.rename_units(new_unit_ids=list(range(sorting.get_num_units())))

    analyzer = create_sorting_analyzer(sorting, recording, format="memory", sparse=True)

    analyzer.compute("random_spikes", max_spikes_per_unit=300, seed=1205)
    analyzer.compute("noise_levels")
    analyzer.compute("waveforms", **job_kwargs)
    analyzer.compute("templates")
    analyzer.compute(["spike_amplitudes", "spike_locations"], **job_kwargs)

    return analyzer
47 changes: 47 additions & 0 deletions src/spikeinterface/metrics/quality/misc_metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -851,6 +851,52 @@ class AmplitudeMedian(BaseMetric):
depend_on = ["spike_amplitudes"]


def compute_waveform_ptp_medians(sorting_analyzer, unit_ids=None):
    """
    Compute the median of the peak-to-peak (PTP) values of the waveforms.

    The PTP value is taken on each unit's extremum channel (selected with
    ``mode="peak_to_peak"``), giving one PTP amplitude per spike; the median
    over spikes is returned per unit.

    Parameters
    ----------
    sorting_analyzer : SortingAnalyzer
        A SortingAnalyzer object.
    unit_ids : list or None, default: None
        List of unit ids to compute the waveform PTP medians. If None, all units are used.

    Returns
    -------
    all_waveform_ptp_medians : dict
        Estimated waveform PTP median for each unit ID.

    References
    ----------
    Inspired by the BombCell quality-metric suite
    (see Fabre et al., "Bombcell: automated curation and cell classification
    of spike-sorted electrophysiology data" -- TODO confirm exact citation).
    """
    if unit_ids is None:
        unit_ids = sorting_analyzer.unit_ids

    check_has_required_extensions("waveform_ptp_median", sorting_analyzer)

    wfs_ext = sorting_analyzer.get_extension("waveforms")
    extremum_channel_indices = get_template_extremum_channel(sorting_analyzer, outputs="index", mode="peak_to_peak")
    all_waveform_ptp_medians = {}

    for unit_id in unit_ids:
        # Dense waveforms so the extremum channel index is valid for every unit.
        waveforms = wfs_ext.get_waveforms_one_unit(unit_id, force_dense=True)
        waveform_max_channel = waveforms[:, :, extremum_channel_indices[unit_id]]
        # PTP over the sample axis -> one amplitude per spike; median is robust to outliers.
        ptps = np.ptp(waveform_max_channel, axis=1)
        all_waveform_ptp_medians[unit_id] = np.median(ptps)

    return all_waveform_ptp_medians


class WaveformPTPMedian(BaseMetric):
    # Quality metric: median peak-to-peak waveform amplitude per unit,
    # computed by `compute_waveform_ptp_medians` on the extremum channel.
    metric_name = "waveform_ptp_median"
    metric_function = compute_waveform_ptp_medians
    # One float column in the resulting metrics table.
    metric_columns = {"waveform_ptp_median": float}
    # Requires the "waveforms" extension to have been computed on the analyzer.
    depend_on = ["waveforms"]


def compute_noise_cutoffs(sorting_analyzer, unit_ids=None, high_quantile=0.25, low_quantile=0.1, n_bins=100):
"""
A metric to determine if a unit's amplitude distribution is cut off as it approaches zero, without assuming a Gaussian distribution.
Expand Down Expand Up @@ -1258,6 +1304,7 @@ class SDRatio(BaseMetric):
AmplitudeCutoff,
NoiseCutoff,
AmplitudeMedian,
WaveformPTPMedian,
Drift,
SDRatio,
]
Expand Down
85 changes: 0 additions & 85 deletions src/spikeinterface/metrics/quality/tests/conftest.py

This file was deleted.

27 changes: 7 additions & 20 deletions src/spikeinterface/metrics/quality/tests/test_metrics_functions.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import pytest
from pathlib import Path
import numpy as np
from copy import deepcopy
import csv
Expand All @@ -14,26 +13,21 @@

from spikeinterface.metrics.quality.utils import create_ground_truth_pc_distributions

# from spikeinterface.metrics.quality_metric_list import (
# _misc_metric_name_to_func,
# )

from spikeinterface.metrics.quality import (
get_quality_metric_list,
get_quality_pca_metric_list,
compute_quality_metrics,
)
from spikeinterface.metrics.quality.misc_metrics import (
misc_metrics_list,
compute_amplitude_cutoffs,
compute_presence_ratios,
compute_isi_violations,
# compute_firing_rates,
# compute_num_spikes,
compute_snrs,
compute_refrac_period_violations,
compute_sliding_rp_violations,
compute_drift_metrics,
compute_waveform_ptp_medians,
compute_amplitude_medians,
compute_synchrony_metrics,
compute_firing_ranges,
Expand All @@ -44,7 +38,6 @@
)

from spikeinterface.metrics.quality.pca_metrics import (
pca_metrics_list,
mahalanobis_metrics,
d_prime_metric,
nearest_neighbors_metrics,
Expand Down Expand Up @@ -486,18 +479,6 @@ def test_simplified_silhouette_score_metrics():
assert sim_sil_score1 < sim_sil_score2


# def test_calculate_firing_rate_num_spikes(sorting_analyzer_simple):
# sorting_analyzer = sorting_analyzer_simple
# firing_rates = compute_firing_rates(sorting_analyzer)
# num_spikes = compute_num_spikes(sorting_analyzer)

# testing method accuracy with a magic number is not a good practice, I removed this.
# firing_rates_gt = {0: 10.01, 1: 5.03, 2: 5.09}
# num_spikes_gt = {0: 1001, 1: 503, 2: 509}
# assert np.allclose(list(firing_rates_gt.values()), list(firing_rates.values()), rtol=0.05)
# np.testing.assert_array_equal(list(num_spikes_gt.values()), list(num_spikes.values()))


def test_calculate_firing_range(sorting_analyzer_simple):
sorting_analyzer = sorting_analyzer_simple
firing_ranges = compute_firing_ranges(sorting_analyzer)
Expand Down Expand Up @@ -532,6 +513,12 @@ def test_calculate_amplitude_median(sorting_analyzer_simple):
# assert np.allclose(list(amp_medians_gt.values()), list(amp_medians.values()), rtol=0.05)


def test_calculate_waveform_ptp_median(sorting_analyzer_simple):
    """Smoke-test the waveform PTP median metric: one positive value per unit."""
    sorting_analyzer = sorting_analyzer_simple
    wfs_ptp_medians = compute_waveform_ptp_medians(sorting_analyzer)
    # Previously this test had no assertions (and a commented-out dead line);
    # peak-to-peak amplitudes of noisy waveforms are strictly positive.
    assert set(wfs_ptp_medians.keys()) == set(sorting_analyzer.unit_ids)
    assert np.all(np.array(list(wfs_ptp_medians.values())) > 0)


def test_calculate_amplitude_cv_metrics(sorting_analyzer_simple):
sorting_analyzer = sorting_analyzer_simple
amp_cv_median, amp_cv_range = compute_amplitude_cv_metrics(sorting_analyzer, average_num_spikes_per_bin=20)
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
import numpy as np
from spikeinterface.metrics.spiketrain import (
compute_firing_rates,
compute_num_spikes,
)


def test_calculate_firing_rates(sorting_analyzer_simple):
    """Every ground-truth unit fires, so all firing rates must be positive."""
    # Fixed typo in the test name ("ratess" -> "rates"); pytest discovers
    # tests by the `test_` prefix, so nothing references the old name.
    sorting_analyzer = sorting_analyzer_simple
    firing_rates = compute_firing_rates(sorting_analyzer)
    assert np.all(np.array(list(firing_rates.values())) > 0)


def test_calculate_num_spikes(sorting_analyzer_simple):
    """Every ground-truth unit should have at least one spike."""
    num_spikes = compute_num_spikes(sorting_analyzer_simple)
    for count in num_spikes.values():
        assert count > 0