
Commit 211c222

Merge pull request #2973 from chrishalcrow/add-tests-for-qm-structure
Add test to check unit structure in quality metric calculator output
2 parents 5a7d890 + aa314cd commit 211c222

File tree

2 files changed: +129 -5 lines changed


src/spikeinterface/qualitymetrics/misc_metrics.py

Lines changed: 5 additions & 5 deletions
@@ -388,11 +388,11 @@ def compute_refrac_period_violations(
     nb_violations = {}
     rp_contamination = {}
 
-    for i, unit_id in enumerate(sorting.unit_ids):
+    for unit_index, unit_id in enumerate(sorting.unit_ids):
         if unit_id not in unit_ids:
             continue
 
-        nb_violations[unit_id] = n_v = nb_rp_violations[i]
+        nb_violations[unit_id] = n_v = nb_rp_violations[unit_index]
         N = num_spikes[unit_id]
         if N == 0:
             rp_contamination[unit_id] = np.nan
@@ -1085,10 +1085,10 @@ def compute_drift_metrics(
             spikes_in_bin = spikes_in_segment[i0:i1]
             spike_locations_in_bin = spike_locations_in_segment[i0:i1][direction]
 
-            for unit_ind in np.arange(len(unit_ids)):
-                mask = spikes_in_bin["unit_index"] == unit_ind
+            for i, unit_id in enumerate(unit_ids):
+                mask = spikes_in_bin["unit_index"] == sorting.id_to_index(unit_id)
                 if np.sum(mask) >= min_spikes_per_interval:
-                    median_positions[unit_ind, bin_index] = np.median(spike_locations_in_bin[mask])
+                    median_positions[i, bin_index] = np.median(spike_locations_in_bin[mask])
         if median_position_segments is None:
            median_position_segments = median_positions
        else:
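
The change in compute_drift_metrics matters when the caller passes a subset of unit_ids: a unit's position inside that subset no longer coincides with its index in the full sorting, which is what the spike vector's "unit_index" field stores. A minimal, self-contained sketch of the distinction (plain NumPy, illustrative names only, not library code):

# Illustrative sketch only (not library code): why a unit's position in a
# user-supplied subset of unit_ids differs from its index in the full sorting,
# which is what the spike vector's "unit_index" field refers to.
import numpy as np

all_unit_ids = ["#3", "#9", "#4"]             # order of sorting.unit_ids
spike_unit_index = np.array([0, 1, 1, 2, 0])  # per-spike index into all_unit_ids
subset = ["#4", "#9"]                         # unit_ids requested by the caller

id_to_index = {uid: ind for ind, uid in enumerate(all_unit_ids)}

# Buggy pattern: the enumeration position within the subset (0 -> "#4", 1 -> "#9")
# is used as if it were the sorting index, so "#4" picks up the spikes of "#3".
wrong_counts = {uid: int(np.sum(spike_unit_index == pos)) for pos, uid in enumerate(subset)}

# Fixed pattern: map each unit id back to its index in the full sorting first,
# analogous to sorting.id_to_index(unit_id) in the patched compute_drift_metrics.
right_counts = {uid: int(np.sum(spike_unit_index == id_to_index[uid])) for uid in subset}

print(wrong_counts)  # {'#4': 2, '#9': 2} -- "#4" was given "#3"'s spike count
print(right_counts)  # {'#4': 1, '#9': 2}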

src/spikeinterface/qualitymetrics/tests/test_metrics_functions.py

Lines changed: 124 additions & 0 deletions
@@ -12,8 +12,12 @@
 
 from spikeinterface.qualitymetrics.utils import create_ground_truth_pc_distributions
 
+from spikeinterface.qualitymetrics.quality_metric_list import (
+    _misc_metric_name_to_func,
+)
 
 from spikeinterface.qualitymetrics import (
+    get_quality_metric_list,
     mahalanobis_metrics,
     lda_metrics,
     nearest_neighbors_metrics,
@@ -34,6 +38,7 @@
     compute_amplitude_cv_metrics,
     compute_sd_ratio,
     get_synchrony_counts,
+    compute_quality_metrics,
 )
 
 from spikeinterface.core.basesorting import minimum_spike_dtype
@@ -42,6 +47,125 @@
 job_kwargs = dict(n_jobs=2, progress_bar=True, chunk_duration="1s")
 
 
+def _small_sorting_analyzer():
+    recording, sorting = generate_ground_truth_recording(
+        durations=[2.0],
+        num_units=4,
+        seed=1205,
+    )
+
+    sorting = sorting.select_units([3, 2, 0], ["#3", "#9", "#4"])
+
+    sorting_analyzer = create_sorting_analyzer(recording=recording, sorting=sorting, format="memory")
+
+    extensions_to_compute = {
+        "random_spikes": {"seed": 1205},
+        "noise_levels": {"seed": 1205},
+        "waveforms": {},
+        "templates": {},
+        "spike_amplitudes": {},
+        "spike_locations": {},
+        "principal_components": {},
+    }
+
+    sorting_analyzer.compute(extensions_to_compute)
+
+    return sorting_analyzer
+
+
+@pytest.fixture(scope="module")
+def small_sorting_analyzer():
+    return _small_sorting_analyzer()
+
+
+def test_unit_structure_in_output(small_sorting_analyzer):
+
+    qm_params = {
+        "presence_ratio": {"bin_duration_s": 0.1},
+        "amplitude_cutoff": {"num_histogram_bins": 3},
+        "amplitude_cv": {"average_num_spikes_per_bin": 7, "min_num_bins": 3},
+        "firing_range": {"bin_size_s": 1},
+        "isi_violation": {"isi_threshold_ms": 10},
+        "drift": {"interval_s": 1, "min_spikes_per_interval": 5},
+        "sliding_rp_violation": {"max_ref_period_ms": 50, "bin_size_ms": 0.15},
+        "rp_violation": {"refractory_period_ms": 10.0, "censored_period_ms": 0.0},
+    }
+
+    for metric_name in get_quality_metric_list():
+
+        try:
+            qm_param = qm_params[metric_name]
+        except:
+            qm_param = {}
+
+        result_all = _misc_metric_name_to_func[metric_name](sorting_analyzer=small_sorting_analyzer, **qm_param)
+        result_sub = _misc_metric_name_to_func[metric_name](
+            sorting_analyzer=small_sorting_analyzer, unit_ids=["#4", "#9"], **qm_param
+        )
+
+        if isinstance(result_all, dict):
+            assert list(result_all.keys()) == ["#3", "#9", "#4"]
+            assert list(result_sub.keys()) == ["#4", "#9"]
+            assert result_sub["#9"] == result_all["#9"]
+            assert result_sub["#4"] == result_all["#4"]
+
+        else:
+            for result_ind, result in enumerate(result_sub):
+
+                assert list(result_all[result_ind].keys()) == ["#3", "#9", "#4"]
+                assert result_sub[result_ind].keys() == set(["#4", "#9"])
+
+                assert result_sub[result_ind]["#9"] == result_all[result_ind]["#9"]
+                assert result_sub[result_ind]["#4"] == result_all[result_ind]["#4"]
+
+
+def test_unit_id_order_independence(small_sorting_analyzer):
+    """
+    Takes two almost-identical sorting_analyzers, whose unit_ids are in different orders and have different labels,
+    and checks that their calculated quality metrics are independent of the ordering and labelling.
+    """
+
+    recording = small_sorting_analyzer.recording
+    sorting = small_sorting_analyzer.sorting.select_units(["#4", "#9", "#3"], [0, 2, 3])
+
+    small_sorting_analyzer_2 = create_sorting_analyzer(recording=recording, sorting=sorting, format="memory")
+
+    extensions_to_compute = {
+        "random_spikes": {"seed": 1205},
+        "noise_levels": {"seed": 1205},
+        "waveforms": {},
+        "templates": {},
+        "spike_amplitudes": {},
+        "spike_locations": {},
+        "principal_components": {},
+    }
+
+    small_sorting_analyzer_2.compute(extensions_to_compute)
+
+    # need special params to get non-nan results on a short recording
+    qm_params = {
+        "presence_ratio": {"bin_duration_s": 0.1},
+        "amplitude_cutoff": {"num_histogram_bins": 3},
+        "amplitude_cv": {"average_num_spikes_per_bin": 7, "min_num_bins": 3},
+        "firing_range": {"bin_size_s": 1},
+        "isi_violation": {"isi_threshold_ms": 10},
+        "drift": {"interval_s": 1, "min_spikes_per_interval": 5},
+        "sliding_rp_violation": {"max_ref_period_ms": 50, "bin_size_ms": 0.15},
+    }
+
+    quality_metrics_1 = compute_quality_metrics(
+        small_sorting_analyzer, metric_names=get_quality_metric_list(), qm_params=qm_params
+    )
+    quality_metrics_2 = compute_quality_metrics(
+        small_sorting_analyzer_2, metric_names=get_quality_metric_list(), qm_params=qm_params
+    )
+
+    for metric, metric_1_data in quality_metrics_1.items():
+        assert quality_metrics_2[metric][3] == metric_1_data["#3"]
+        assert quality_metrics_2[metric][2] == metric_1_data["#9"]
+        assert quality_metrics_2[metric][0] == metric_1_data["#4"]
+
+
 def _sorting_analyzer_simple():
     recording, sorting = generate_ground_truth_recording(
         durations=[
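
For reference, the per-unit structure that test_unit_structure_in_output enforces can be summarised with a stand-in metric function (an assumed sketch of the contract, not SpikeInterface code): each misc metric returns either a dict keyed by unit id, or a namedtuple whose fields are such dicts, and restricting the computation to a subset of unit_ids must leave the per-unit values unchanged.

# Assumed sketch of the contract checked by test_unit_structure_in_output;
# fake_metric below is a hypothetical stand-in, not a SpikeInterface metric.
from collections import namedtuple

FakeViolations = namedtuple("FakeViolations", ["ratio", "count"])

def fake_metric(unit_ids):
    # one entry per requested unit id, keyed by the unit id itself
    ratio = {uid: float(uid.strip("#")) / 10 for uid in unit_ids}
    count = {uid: int(uid.strip("#")) for uid in unit_ids}
    return FakeViolations(ratio, count)

result_all = fake_metric(["#3", "#9", "#4"])
result_sub = fake_metric(["#4", "#9"])

# dict-of-dicts structure: keys are unit ids, and values for the subset
# must match the values obtained when computing over all units
for field_all, field_sub in zip(result_all, result_sub):
    assert set(field_sub.keys()) == {"#4", "#9"}
    for uid in field_sub:
        assert field_sub[uid] == field_all[uid]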
