Skip to content

Commit 53933e6

Browse files
authored
Merge pull request #3180 from h-mayorquin/sorting_aggregation_should_preserve_ids
Units aggregation preserve unit ids of aggregated sorters
2 parents e2b0a34 + f7018c6 commit 53933e6

File tree

2 files changed

+53
-1
lines changed

2 files changed

+53
-1
lines changed

src/spikeinterface/core/tests/test_unitsaggregationsorting.py

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55

66
from spikeinterface.core import NpzSortingExtractor
77
from spikeinterface.core import create_sorting_npz
8+
from spikeinterface.core import generate_sorting
89

910

1011
def test_unitsaggregationsorting(create_cache_folder):
@@ -92,5 +93,42 @@ def test_unitsaggregationsorting(create_cache_folder):
9293
print(sorting_agg_prop.get_property("brain_area"))
9394

9495

96+
def test_unit_aggregation_preserve_ids():
97+
98+
sorting1 = generate_sorting(num_units=3)
99+
sorting1 = sorting1.rename_units(new_unit_ids=["unit1", "unit2", "unit3"])
100+
101+
sorting2 = generate_sorting(num_units=3)
102+
sorting2 = sorting2.rename_units(new_unit_ids=["unit4", "unit5", "unit6"])
103+
104+
aggregated_sorting = aggregate_units([sorting1, sorting2])
105+
assert aggregated_sorting.get_num_units() == 6
106+
assert list(aggregated_sorting.get_unit_ids()) == ["unit1", "unit2", "unit3", "unit4", "unit5", "unit6"]
107+
108+
109+
def test_unit_aggregation_does_not_preserve_ids_if_not_unique():
110+
sorting1 = generate_sorting(num_units=3)
111+
sorting1 = sorting1.rename_units(new_unit_ids=["unit1", "unit2", "unit3"])
112+
113+
sorting2 = generate_sorting(num_units=3)
114+
sorting2 = sorting2.rename_units(new_unit_ids=["unit1", "unit2", "unit3"])
115+
116+
aggregated_sorting = aggregate_units([sorting1, sorting2])
117+
assert aggregated_sorting.get_num_units() == 6
118+
assert list(aggregated_sorting.get_unit_ids()) == ["0", "1", "2", "3", "4", "5"]
119+
120+
121+
def test_unit_aggregation_does_not_preserve_ids_not_the_same_type():
122+
sorting1 = generate_sorting(num_units=3)
123+
sorting1 = sorting1.rename_units(new_unit_ids=["unit1", "unit2", "unit3"])
124+
125+
sorting2 = generate_sorting(num_units=2)
126+
sorting2 = sorting2.rename_units(new_unit_ids=[1, 2])
127+
128+
aggregated_sorting = aggregate_units([sorting1, sorting2])
129+
assert aggregated_sorting.get_num_units() == 5
130+
assert list(aggregated_sorting.get_unit_ids()) == ["0", "1", "2", "3", "4"]
131+
132+
95133
if __name__ == "__main__":
96134
test_unitsaggregationsorting()

src/spikeinterface/core/unitsaggregationsorting.py

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,21 @@ def __init__(self, sorting_list, renamed_unit_ids=None):
3434
)
3535
unit_ids = list(renamed_unit_ids)
3636
else:
37-
unit_ids = list(np.arange(num_all_units))
37+
unit_ids_dtypes = [sort.get_unit_ids().dtype for sort in sorting_list]
38+
all_ids_are_same_type = np.unique(unit_ids_dtypes).size == 1
39+
all_units_ids_are_unique = False
40+
if all_ids_are_same_type:
41+
combined_ids = np.concatenate([sort.get_unit_ids() for sort in sorting_list])
42+
all_units_ids_are_unique = np.unique(combined_ids).size == num_all_units
43+
44+
if all_ids_are_same_type and all_units_ids_are_unique:
45+
unit_ids = combined_ids
46+
else:
47+
default_unit_ids = [str(i) for i in range(num_all_units)]
48+
if all_ids_are_same_type and np.issubdtype(unit_ids_dtypes[0], np.integer):
49+
unit_ids = np.arange(num_all_units, dtype=np.uint64)
50+
else:
51+
unit_ids = default_unit_ids
3852

3953
# unit map maps unit ids that are used to get spike trains
4054
u_id = 0

0 commit comments

Comments
 (0)