1717from ..core import get_random_data_chunks , compute_sparsity
1818from ..core .template_tools import get_template_extremum_channel
1919
20-
2120_possible_pc_metric_names = [
2221 "isolation_distance" ,
2322 "l_ratio" ,
@@ -90,7 +89,7 @@ def compute_pc_metrics(
9089 sorting = sorting_analyzer .sorting
9190
9291 if metric_names is None :
93- metric_names = _possible_pc_metric_names
92+ metric_names = _possible_pc_metric_names . copy ()
9493 if qm_params is None :
9594 qm_params = _default_params
9695
@@ -110,8 +109,13 @@ def compute_pc_metrics(
110109 if "nn_isolation" in metric_names :
111110 pc_metrics ["nn_unit_id" ] = {}
112111
112+ possible_nn_metrics = ["nn_isolation" , "nn_noise_overlap" ]
113+
114+ nn_metrics = list (set (metric_names ).intersection (possible_nn_metrics ))
115+ non_nn_metrics = list (set (metric_names ).difference (possible_nn_metrics ))
116+
113117 # Compute nspikes and firing rate outside of main loop for speed
114- if any ([ n in metric_names for n in [ "nn_isolation" , "nn_noise_overlap" ]]) :
118+ if nn_metrics :
115119 n_spikes_all_units = compute_num_spikes (sorting_analyzer , unit_ids = unit_ids )
116120 fr_all_units = compute_firing_rates (sorting_analyzer , unit_ids = unit_ids )
117121 else :
@@ -120,9 +124,6 @@ def compute_pc_metrics(
120124
121125 run_in_parallel = n_jobs > 1
122126
123- if run_in_parallel :
124- parallel_functions = []
125-
126127 # this get dense projection for selected unit_ids
127128 dense_projections , spike_unit_indices = pca_ext .get_some_projections (channel_ids = None , unit_ids = unit_ids )
128129 all_labels = sorting .unit_ids [spike_unit_indices ]
@@ -146,7 +147,7 @@ def compute_pc_metrics(
146147 func_args = (
147148 pcs_flat ,
148149 labels ,
149- metric_names ,
150+ non_nn_metrics ,
150151 unit_id ,
151152 unit_ids ,
152153 qm_params ,
@@ -156,16 +157,16 @@ def compute_pc_metrics(
156157 )
157158 items .append (func_args )
158159
159- if not run_in_parallel :
160+ if not run_in_parallel and non_nn_metrics :
160161 units_loop = enumerate (unit_ids )
161162 if progress_bar :
162- units_loop = tqdm (units_loop , desc = "calculate_pc_metrics " , total = len (unit_ids ))
163+ units_loop = tqdm (units_loop , desc = "calculate pc_metrics " , total = len (unit_ids ))
163164
164165 for unit_ind , unit_id in units_loop :
165166 pca_metrics_unit = pca_metrics_one_unit (items [unit_ind ])
166167 for metric_name , metric in pca_metrics_unit .items ():
167168 pc_metrics [metric_name ][unit_id ] = metric
168- else :
169+ elif run_in_parallel and non_nn_metrics :
169170 with ProcessPoolExecutor (n_jobs ) as executor :
170171 results = executor .map (pca_metrics_one_unit , items )
171172 if progress_bar :
@@ -176,6 +177,37 @@ def compute_pc_metrics(
176177 for metric_name , metric in pca_metrics_unit .items ():
177178 pc_metrics [metric_name ][unit_id ] = metric
178179
180+ for metric_name in nn_metrics :
181+ units_loop = enumerate (unit_ids )
182+ if progress_bar :
183+ units_loop = tqdm (units_loop , desc = f"calculate { metric_name } metric" , total = len (unit_ids ))
184+
185+ func = _nn_metric_name_to_func [metric_name ]
186+ metric_params = qm_params [metric_name ] if metric_name in qm_params else {}
187+
188+ for _ , unit_id in units_loop :
189+ try :
190+ res = func (
191+ sorting_analyzer ,
192+ unit_id ,
193+ seed = seed ,
194+ n_spikes_all_units = n_spikes_all_units ,
195+ fr_all_units = fr_all_units ,
196+ ** metric_params ,
197+ )
198+ except :
199+ if metric_name == "nn_isolation" :
200+ res = (np .nan , np .nan )
201+ elif metric_name == "nn_noise_overlap" :
202+ res = np .nan
203+
204+ if metric_name == "nn_isolation" :
205+ nn_isolation , nn_unit_id = res
206+ pc_metrics ["nn_isolation" ][unit_id ] = nn_isolation
207+ pc_metrics ["nn_unit_id" ][unit_id ] = nn_unit_id
208+ elif metric_name == "nn_noise_overlap" :
209+ pc_metrics ["nn_noise_overlap" ][unit_id ] = res
210+
179211 return pc_metrics
180212
181213
@@ -677,6 +709,14 @@ def nearest_neighbors_noise_overlap(
677709 templates_ext = sorting_analyzer .get_extension ("templates" )
678710 assert templates_ext is not None , "nearest_neighbors_isolation() need extension 'templates'"
679711
712+ try :
713+ sorting_analyzer .get_extension ("templates" ).get_data (operator = "median" )
714+ except KeyError :
715+ warnings .warn (
716+ "nearest_neighbors_isolation() need extension 'templates' calculated with the 'median' operator."
717+ "You can run sorting_analyzer.compute('templates', operators=['average', 'median']) to calculate templates based on both average and median modes."
718+ )
719+
680720 if n_spikes_all_units is None :
681721 n_spikes_all_units = compute_num_spikes (sorting_analyzer )
682722 if fr_all_units is None :
@@ -955,11 +995,13 @@ def pca_metrics_one_unit(args):
955995 pc_metrics = {}
956996 # metrics
957997 if "isolation_distance" in metric_names or "l_ratio" in metric_names :
998+
958999 try :
9591000 isolation_distance , l_ratio = mahalanobis_metrics (pcs_flat , labels , unit_id )
9601001 except :
9611002 isolation_distance = np .nan
9621003 l_ratio = np .nan
1004+
9631005 if "isolation_distance" in metric_names :
9641006 pc_metrics ["isolation_distance" ] = isolation_distance
9651007 if "l_ratio" in metric_names :
@@ -973,6 +1015,7 @@ def pca_metrics_one_unit(args):
9731015 d_prime = lda_metrics (pcs_flat , labels , unit_id )
9741016 except :
9751017 d_prime = np .nan
1018+
9761019 pc_metrics ["d_prime" ] = d_prime
9771020
9781021 if "nearest_neighbor" in metric_names :
@@ -986,36 +1029,6 @@ def pca_metrics_one_unit(args):
9861029 pc_metrics ["nn_hit_rate" ] = nn_hit_rate
9871030 pc_metrics ["nn_miss_rate" ] = nn_miss_rate
9881031
989- if "nn_isolation" in metric_names :
990- try :
991- nn_isolation , nn_unit_id = nearest_neighbors_isolation (
992- we ,
993- unit_id ,
994- seed = seed ,
995- n_spikes_all_units = n_spikes_all_units ,
996- fr_all_units = fr_all_units ,
997- ** qm_params ["nn_isolation" ],
998- )
999- except :
1000- nn_isolation = np .nan
1001- nn_unit_id = np .nan
1002- pc_metrics ["nn_isolation" ] = nn_isolation
1003- pc_metrics ["nn_unit_id" ] = nn_unit_id
1004-
1005- if "nn_noise_overlap" in metric_names :
1006- try :
1007- nn_noise_overlap = nearest_neighbors_noise_overlap (
1008- we ,
1009- unit_id ,
1010- n_spikes_all_units = n_spikes_all_units ,
1011- fr_all_units = fr_all_units ,
1012- seed = seed ,
1013- ** qm_params ["nn_noise_overlap" ],
1014- )
1015- except :
1016- nn_noise_overlap = np .nan
1017- pc_metrics ["nn_noise_overlap" ] = nn_noise_overlap
1018-
10191032 if "silhouette" in metric_names :
10201033 silhouette_method = qm_params ["silhouette" ]["method" ]
10211034 if "simplified" in silhouette_method :
@@ -1032,3 +1045,9 @@ def pca_metrics_one_unit(args):
10321045 pc_metrics ["silhouette_full" ] = unit_silhouette_score
10331046
10341047 return pc_metrics
1048+
1049+
1050+ _nn_metric_name_to_func = {
1051+ "nn_isolation" : nearest_neighbors_isolation ,
1052+ "nn_noise_overlap" : nearest_neighbors_noise_overlap ,
1053+ }
0 commit comments