Skip to content

Commit a96821c — "Rename spike_labels -> spike_unit_indices"
1 parent: 2741273

File tree

2 files changed: +19 additions, −19 deletions

src/spikeinterface/postprocessing/correlograms.py

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -262,16 +262,16 @@ def _compute_correlograms_numpy(sorting, window_size, bin_size):
262262

263263
for seg_index in range(num_seg):
264264
spike_times = spikes[seg_index]["sample_index"]
265-
spike_labels = spikes[seg_index]["unit_index"]
265+
spike_unit_indices = spikes[seg_index]["unit_index"]
266266

267-
c0 = correlogram_for_one_segment(spike_times, spike_labels, window_size, bin_size)
267+
c0 = correlogram_for_one_segment(spike_times, spike_unit_indices, window_size, bin_size)
268268

269269
correlograms += c0
270270

271271
return correlograms
272272

273273

274-
def correlogram_for_one_segment(spike_times, spike_labels, window_size, bin_size):
274+
def correlogram_for_one_segment(spike_times, spike_unit_indices, window_size, bin_size):
275275
"""
276276
A very well optimized algorithm for the cross-correlation of
277277
spike trains, copied from the Phy package, written by Cyrille Rossant.
@@ -281,7 +281,7 @@ def correlogram_for_one_segment(spike_times, spike_labels, window_size, bin_size
281281
spike_times : np.ndarray
282282
An array of spike times (in samples, not seconds).
283283
This contains spikes from all units.
284-
spike_labels : np.ndarray
284+
spike_unit_indices : np.ndarray
285285
An array of labels indicating the unit of the corresponding
286286
spike in `spike_times`.
287287
window_size : int
@@ -315,7 +315,7 @@ def correlogram_for_one_segment(spike_times, spike_labels, window_size, bin_size
315315
match within the window size.
316316
"""
317317
num_bins, num_half_bins = _compute_num_bins(window_size, bin_size)
318-
num_units = len(np.unique(spike_labels))
318+
num_units = len(np.unique(spike_unit_indices))
319319

320320
correlograms = np.zeros((num_units, num_units, num_bins), dtype="int64")
321321

@@ -347,12 +347,12 @@ def correlogram_for_one_segment(spike_times, spike_labels, window_size, bin_size
347347
# to be incremented, taking into account the spike unit labels.
348348
if sign == 1:
349349
indices = np.ravel_multi_index(
350-
(spike_labels[+shift:][m], spike_labels[:-shift][m], spike_diff_b[m] + num_half_bins),
350+
(spike_unit_indices[+shift:][m], spike_unit_indices[:-shift][m], spike_diff_b[m] + num_half_bins),
351351
correlograms.shape,
352352
)
353353
else:
354354
indices = np.ravel_multi_index(
355-
(spike_labels[:-shift][m], spike_labels[+shift:][m], spike_diff_b[m] + num_half_bins),
355+
(spike_unit_indices[:-shift][m], spike_unit_indices[+shift:][m], spike_diff_b[m] + num_half_bins),
356356
correlograms.shape,
357357
)
358358

@@ -411,12 +411,12 @@ def _compute_correlograms_numba(sorting, window_size, bin_size):
411411

412412
for seg_index in range(sorting.get_num_segments()):
413413
spike_times = spikes[seg_index]["sample_index"]
414-
spike_labels = spikes[seg_index]["unit_index"]
414+
spike_unit_indices = spikes[seg_index]["unit_index"]
415415

416416
_compute_correlograms_one_segment_numba(
417417
correlograms,
418418
spike_times.astype(np.int64, copy=False),
419-
spike_labels.astype(np.int32, copy=False),
419+
spike_unit_indices.astype(np.int32, copy=False),
420420
window_size,
421421
bin_size,
422422
num_half_bins,
@@ -433,7 +433,7 @@ def _compute_correlograms_numba(sorting, window_size, bin_size):
433433
cache=False,
434434
)
435435
def _compute_correlograms_one_segment_numba(
436-
correlograms, spike_times, spike_labels, window_size, bin_size, num_half_bins
436+
correlograms, spike_times, spike_unit_indices, window_size, bin_size, num_half_bins
437437
):
438438
"""
439439
Compute the correlograms using `numba` for speed.
@@ -455,7 +455,7 @@ def _compute_correlograms_one_segment_numba(
455455
spike_times : np.ndarray
456456
An array of spike times (in samples, not seconds).
457457
This contains spikes from all units.
458-
spike_labels : np.ndarray
458+
spike_unit_indices : np.ndarray
459459
An array of labels indicating the unit of the corresponding
460460
spike in `spike_times`.
461461
window_size : int
@@ -494,4 +494,4 @@ def _compute_correlograms_one_segment_numba(
494494

495495
bin = diff // bin_size
496496

497-
correlograms[spike_labels[i], spike_labels[j], num_half_bins + bin] += 1
497+
correlograms[spike_unit_indices[i], spike_unit_indices[j], num_half_bins + bin] += 1

src/spikeinterface/postprocessing/tests/test_correlograms.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -189,18 +189,18 @@ def test_compute_correlograms(fill_all_bins, on_time_bin, multi_segment):
189189
counts should double when two segments with identical spike times / labels are used.
190190
"""
191191
sampling_frequency = 30000
192-
window_ms, bin_ms, spike_times, spike_labels, expected_bins, expected_result_auto, expected_result_corr = (
192+
window_ms, bin_ms, spike_times, spike_unit_indices, expected_bins, expected_result_auto, expected_result_corr = (
193193
generate_correlogram_test_dataset(sampling_frequency, fill_all_bins, on_time_bin)
194194
)
195195

196196
if multi_segment:
197197
sorting = NumpySorting.from_times_labels(
198-
times_list=[spike_times], labels_list=[spike_labels], sampling_frequency=sampling_frequency
198+
times_list=[spike_times], labels_list=[spike_unit_indices], sampling_frequency=sampling_frequency
199199
)
200200
else:
201201
sorting = NumpySorting.from_times_labels(
202202
times_list=[spike_times, spike_times],
203-
labels_list=[spike_labels, spike_labels],
203+
labels_list=[spike_unit_indices, spike_unit_indices],
204204
sampling_frequency=sampling_frequency,
205205
)
206206
expected_result_auto *= 2
@@ -235,13 +235,13 @@ def test_compute_correlograms_different_units(method):
235235
spike_times = np.array([0, 4, 8, 16]) / 1000 * sampling_frequency
236236
spike_times.astype(int)
237237

238-
spike_labels = np.array([0, 1, 0, 1])
238+
spike_unit_indices = np.array([0, 1, 0, 1])
239239

240240
window_ms = 40
241241
bin_ms = 5
242242

243243
sorting = NumpySorting.from_times_labels(
244-
times_list=[spike_times], labels_list=[spike_labels], sampling_frequency=sampling_frequency
244+
times_list=[spike_times], labels_list=[spike_unit_indices], sampling_frequency=sampling_frequency
245245
)
246246

247247
result, bins = compute_correlograms(sorting, window_ms=window_ms, bin_ms=bin_ms, method=method)
@@ -323,7 +323,7 @@ def generate_correlogram_test_dataset(sampling_frequency, fill_all_bins, hit_bin
323323
# Now, make a set of times that increase by `base_diff_time` e.g.
324324
# if base_diff_time=0.0051 then our spike times are [`0.0051, 0.0102, ...]`
325325
spike_times = np.repeat(np.arange(num_filled_bins), num_units) * base_diff_time
326-
spike_labels = np.tile(np.arange(num_units), int(spike_times.size / num_units))
326+
spike_unit_indices = np.tile(np.arange(num_units), int(spike_times.size / num_units))
327327

328328
spike_times *= sampling_frequency
329329
spike_times = spike_times.astype(int)
@@ -368,4 +368,4 @@ def generate_correlogram_test_dataset(sampling_frequency, fill_all_bins, hit_bin
368368
expected_result_corr = expected_result_auto.copy()
369369
expected_result_corr[int(num_bins / 2)] = num_filled_bins
370370

371-
return window_ms, bin_ms, spike_times, spike_labels, expected_bins, expected_result_auto, expected_result_corr
371+
return window_ms, bin_ms, spike_times, spike_unit_indices, expected_bins, expected_result_auto, expected_result_corr

Comments (0)