|
5 | 5 |
|
6 | 6 | from spikeinterface.core import NpzSortingExtractor |
7 | 7 | from spikeinterface.core import create_sorting_npz |
| 8 | +from spikeinterface.core import generate_sorting |
8 | 9 |
|
9 | 10 |
|
10 | 11 | def test_unitsaggregationsorting(create_cache_folder): |
@@ -33,10 +34,12 @@ def test_unitsaggregationsorting(create_cache_folder): |
33 | 34 | spiketrain1_1 = sorting1.get_unit_spike_train(unit_ids[1], segment_index=seg) |
34 | 35 | spiketrains2_0 = sorting2.get_unit_spike_train(unit_ids[0], segment_index=seg) |
35 | 36 | spiketrains3_2 = sorting3.get_unit_spike_train(unit_ids[2], segment_index=seg) |
36 | | - assert np.allclose(spiketrain1_1, sorting_agg.get_unit_spike_train(unit_ids[1], segment_index=seg)) |
37 | | - assert np.allclose(spiketrains2_0, sorting_agg.get_unit_spike_train(num_units + unit_ids[0], segment_index=seg)) |
| 37 | + assert np.allclose(spiketrain1_1, sorting_agg.get_unit_spike_train(str(unit_ids[1]), segment_index=seg)) |
38 | 38 | assert np.allclose( |
39 | | - spiketrains3_2, sorting_agg.get_unit_spike_train(2 * num_units + unit_ids[2], segment_index=seg) |
| 39 | + spiketrains2_0, sorting_agg.get_unit_spike_train(str(num_units + unit_ids[0]), segment_index=seg) |
| 40 | + ) |
| 41 | + assert np.allclose( |
| 42 | + spiketrains3_2, sorting_agg.get_unit_spike_train(str(2 * num_units + unit_ids[2]), segment_index=seg) |
40 | 43 | ) |
41 | 44 |
|
42 | 45 | # test rename units |
@@ -92,5 +95,42 @@ def test_unitsaggregationsorting(create_cache_folder): |
92 | 95 | print(sorting_agg_prop.get_property("brain_area")) |
93 | 96 |
|
94 | 97 |
|
| 98 | +def test_unit_aggregation_preserve_ids(): |
| 99 | + |
| 100 | + sorting1 = generate_sorting(num_units=3) |
| 101 | + sorting1 = sorting1.rename_units(new_unit_ids=["unit1", "unit2", "unit3"]) |
| 102 | + |
| 103 | + sorting2 = generate_sorting(num_units=3) |
| 104 | + sorting2 = sorting2.rename_units(new_unit_ids=["unit4", "unit5", "unit6"]) |
| 105 | + |
| 106 | + aggregated_sorting = aggregate_units([sorting1, sorting2]) |
| 107 | + assert aggregated_sorting.get_num_units() == 6 |
| 108 | + assert list(aggregated_sorting.get_unit_ids()) == ["unit1", "unit2", "unit3", "unit4", "unit5", "unit6"] |
| 109 | + |
| 110 | + |
| 111 | +def test_unit_aggregation_does_not_preserve_ids_if_not_unique(): |
| 112 | + sorting1 = generate_sorting(num_units=3) |
| 113 | + sorting1 = sorting1.rename_units(new_unit_ids=["unit1", "unit2", "unit3"]) |
| 114 | + |
| 115 | + sorting2 = generate_sorting(num_units=3) |
| 116 | + sorting2 = sorting2.rename_units(new_unit_ids=["unit1", "unit2", "unit3"]) |
| 117 | + |
| 118 | + aggregated_sorting = aggregate_units([sorting1, sorting2]) |
| 119 | + assert aggregated_sorting.get_num_units() == 6 |
| 120 | + assert list(aggregated_sorting.get_unit_ids()) == ["0", "1", "2", "3", "4", "5"] |
| 121 | + |
| 122 | + |
| 123 | +def test_unit_aggregation_does_not_preserve_ids_not_the_same_type(): |
| 124 | + sorting1 = generate_sorting(num_units=3) |
| 125 | + sorting1 = sorting1.rename_units(new_unit_ids=["unit1", "unit2", "unit3"]) |
| 126 | + |
| 127 | + sorting2 = generate_sorting(num_units=2) |
| 128 | + sorting2 = sorting2.rename_units(new_unit_ids=[1, 2]) |
| 129 | + |
| 130 | + aggregated_sorting = aggregate_units([sorting1, sorting2]) |
| 131 | + assert aggregated_sorting.get_num_units() == 5 |
| 132 | + assert list(aggregated_sorting.get_unit_ids()) == ["0", "1", "2", "3", "4"] |
| 133 | + |
| 134 | + |
95 | 135 | if __name__ == "__main__": |
96 | 136 | test_unitsaggregationsorting() |
0 commit comments