|
5 | 5 |
|
6 | 6 | from spikeinterface.core import NpzSortingExtractor |
7 | 7 | from spikeinterface.core import create_sorting_npz |
| 8 | +from spikeinterface.core import generate_sorting |
8 | 9 |
|
9 | 10 |
|
10 | 11 | def test_unitsaggregationsorting(create_cache_folder): |
@@ -92,5 +93,42 @@ def test_unitsaggregationsorting(create_cache_folder): |
92 | 93 | print(sorting_agg_prop.get_property("brain_area")) |
93 | 94 |
|
94 | 95 |
|
| 96 | +def test_unit_aggregation_preserve_ids(): |
| 97 | + |
| 98 | + sorting1 = generate_sorting(num_units=3) |
| 99 | + sorting1 = sorting1.rename_units(new_unit_ids=["unit1", "unit2", "unit3"]) |
| 100 | + |
| 101 | + sorting2 = generate_sorting(num_units=3) |
| 102 | + sorting2 = sorting2.rename_units(new_unit_ids=["unit4", "unit5", "unit6"]) |
| 103 | + |
| 104 | + aggregated_sorting = aggregate_units([sorting1, sorting2]) |
| 105 | + assert aggregated_sorting.get_num_units() == 6 |
| 106 | + assert list(aggregated_sorting.get_unit_ids()) == ["unit1", "unit2", "unit3", "unit4", "unit5", "unit6"] |
| 107 | + |
| 108 | + |
| 109 | +def test_unit_aggregation_does_not_preserve_ids_if_not_unique(): |
| 110 | + sorting1 = generate_sorting(num_units=3) |
| 111 | + sorting1 = sorting1.rename_units(new_unit_ids=["unit1", "unit2", "unit3"]) |
| 112 | + |
| 113 | + sorting2 = generate_sorting(num_units=3) |
| 114 | + sorting2 = sorting2.rename_units(new_unit_ids=["unit1", "unit2", "unit3"]) |
| 115 | + |
| 116 | + aggregated_sorting = aggregate_units([sorting1, sorting2]) |
| 117 | + assert aggregated_sorting.get_num_units() == 6 |
| 118 | + assert list(aggregated_sorting.get_unit_ids()) == ["0", "1", "2", "3", "4", "5"] |
| 119 | + |
| 120 | + |
| 121 | +def test_unit_aggregation_does_not_preserve_ids_not_the_same_type(): |
| 122 | + sorting1 = generate_sorting(num_units=3) |
| 123 | + sorting1 = sorting1.rename_units(new_unit_ids=["unit1", "unit2", "unit3"]) |
| 124 | + |
| 125 | + sorting2 = generate_sorting(num_units=2) |
| 126 | + sorting2 = sorting2.rename_units(new_unit_ids=[1, 2]) |
| 127 | + |
| 128 | + aggregated_sorting = aggregate_units([sorting1, sorting2]) |
| 129 | + assert aggregated_sorting.get_num_units() == 5 |
| 130 | + assert list(aggregated_sorting.get_unit_ids()) == ["0", "1", "2", "3", "4"] |
| 131 | + |
| 132 | + |
95 | 133 | if __name__ == "__main__": |
96 | 134 | test_unitsaggregationsorting() |
0 commit comments