Skip to content

Commit b7c4309

Browse files
committed
units aggregation should preserve ids
1 parent bd9cd1f commit b7c4309

File tree

2 files changed

+54
-4
lines changed

2 files changed

+54
-4
lines changed

src/spikeinterface/core/tests/test_unitsaggregationsorting.py

Lines changed: 43 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55

66
from spikeinterface.core import NpzSortingExtractor
77
from spikeinterface.core import create_sorting_npz
8+
from spikeinterface.core import generate_sorting
89

910

1011
def test_unitsaggregationsorting(create_cache_folder):
@@ -33,10 +34,12 @@ def test_unitsaggregationsorting(create_cache_folder):
3334
spiketrain1_1 = sorting1.get_unit_spike_train(unit_ids[1], segment_index=seg)
3435
spiketrains2_0 = sorting2.get_unit_spike_train(unit_ids[0], segment_index=seg)
3536
spiketrains3_2 = sorting3.get_unit_spike_train(unit_ids[2], segment_index=seg)
36-
assert np.allclose(spiketrain1_1, sorting_agg.get_unit_spike_train(unit_ids[1], segment_index=seg))
37-
assert np.allclose(spiketrains2_0, sorting_agg.get_unit_spike_train(num_units + unit_ids[0], segment_index=seg))
37+
assert np.allclose(spiketrain1_1, sorting_agg.get_unit_spike_train(str(unit_ids[1]), segment_index=seg))
3838
assert np.allclose(
39-
spiketrains3_2, sorting_agg.get_unit_spike_train(2 * num_units + unit_ids[2], segment_index=seg)
39+
spiketrains2_0, sorting_agg.get_unit_spike_train(str(num_units + unit_ids[0]), segment_index=seg)
40+
)
41+
assert np.allclose(
42+
spiketrains3_2, sorting_agg.get_unit_spike_train(str(2 * num_units + unit_ids[2]), segment_index=seg)
4043
)
4144

4245
# test rename units
@@ -92,5 +95,42 @@ def test_unitsaggregationsorting(create_cache_folder):
9295
print(sorting_agg_prop.get_property("brain_area"))
9396

9497

98+
def test_unit_aggregation_preserve_ids():
99+
100+
sorting1 = generate_sorting(num_units=3)
101+
sorting1 = sorting1.rename_units(new_unit_ids=["unit1", "unit2", "unit3"])
102+
103+
sorting2 = generate_sorting(num_units=3)
104+
sorting2 = sorting2.rename_units(new_unit_ids=["unit4", "unit5", "unit6"])
105+
106+
aggregated_sorting = aggregate_units([sorting1, sorting2])
107+
assert aggregated_sorting.get_num_units() == 6
108+
assert list(aggregated_sorting.get_unit_ids()) == ["unit1", "unit2", "unit3", "unit4", "unit5", "unit6"]
109+
110+
111+
def test_unit_aggregation_does_not_preserve_ids_if_not_unique():
112+
sorting1 = generate_sorting(num_units=3)
113+
sorting1 = sorting1.rename_units(new_unit_ids=["unit1", "unit2", "unit3"])
114+
115+
sorting2 = generate_sorting(num_units=3)
116+
sorting2 = sorting2.rename_units(new_unit_ids=["unit1", "unit2", "unit3"])
117+
118+
aggregated_sorting = aggregate_units([sorting1, sorting2])
119+
assert aggregated_sorting.get_num_units() == 6
120+
assert list(aggregated_sorting.get_unit_ids()) == ["0", "1", "2", "3", "4", "5"]
121+
122+
123+
def test_unit_aggregation_does_not_preserve_ids_not_the_same_type():
124+
sorting1 = generate_sorting(num_units=3)
125+
sorting1 = sorting1.rename_units(new_unit_ids=["unit1", "unit2", "unit3"])
126+
127+
sorting2 = generate_sorting(num_units=2)
128+
sorting2 = sorting2.rename_units(new_unit_ids=[1, 2])
129+
130+
aggregated_sorting = aggregate_units([sorting1, sorting2])
131+
assert aggregated_sorting.get_num_units() == 5
132+
assert list(aggregated_sorting.get_unit_ids()) == ["0", "1", "2", "3", "4"]
133+
134+
95135
if __name__ == "__main__":
96136
test_unitsaggregationsorting()

src/spikeinterface/core/unitsaggregationsorting.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,17 @@ def __init__(self, sorting_list, renamed_unit_ids=None):
3434
)
3535
unit_ids = list(renamed_unit_ids)
3636
else:
37-
unit_ids = list(np.arange(num_all_units))
37+
all_ids_are_same_type = np.unique([sort.get_unit_ids().dtype for sort in sorting_list]).size == 1
38+
all_units_ids_are_unique = False
39+
if all_ids_are_same_type:
40+
combined_ids = np.concatenate([sort.get_unit_ids() for sort in sorting_list])
41+
all_units_ids_are_unique = np.unique(combined_ids).size == num_all_units
42+
43+
if all_ids_are_same_type and all_units_ids_are_unique:
44+
unit_ids = combined_ids
45+
else:
46+
default_unit_ids = [str(i) for i in range(num_all_units)]
47+
unit_ids = default_unit_ids
3848

3949
# unit map maps unit ids that are used to get spike trains
4050
u_id = 0

0 commit comments

Comments
 (0)