@@ -91,16 +91,28 @@ def _stat(self):
9191 if d .ndim == 2 and d .shape [1 ] == 1 else d )
9292
9393 # Read the Cluster Groups
94- if (self ._ks_dir / 'cluster_groups.csv' ).exists ():
95- df = pd .read_csv (self ._ks_dir / 'cluster_groups.csv' , delimiter = '\t ' )
96- self ._data ['cluster_groups' ] = np .array (df ['group' ].values )
97- self ._data ['cluster_ids' ] = np .array (df ['cluster_id' ].values )
98- elif (self ._ks_dir / 'cluster_KSLabel.tsv' ).exists ():
99- df = pd .read_csv (self ._ks_dir / 'cluster_KSLabel.tsv' , sep = "\t " , header = 0 )
100- self ._data ['cluster_groups' ] = np .array (df ['KSLabel' ].values )
101- self ._data ['cluster_ids' ] = np .array (df ['cluster_id' ].values )
94+ for cluster_pattern , cluster_col_name in zip (['cluster_groups.*' , 'cluster_KSLabel.*' ],
95+ ['group' , 'KSLabel' ]):
96+ try :
97+ cluster_file = next (self ._ks_dir .glob (cluster_pattern ))
98+ cluster_file_suffix = cluster_file .suffix
99+ assert cluster_file_suffix in ('.csv' , '.tsv' , '.xlsx' )
100+ break
101+ except StopIteration :
102+ pass
102103 else :
103- raise FileNotFoundError ('Neither cluster_groups.csv nor cluster_KSLabel.tsv found!' )
104+ raise FileNotFoundError (
105+ 'Neither "cluster_groups" nor "cluster_KSLabel" file found!' )
106+
107+ if cluster_file_suffix == '.tsv' :
108+ df = pd .read_csv (cluster_file , sep = '\t ' , header = 0 )
109+ elif cluster_file_suffix == '.xlsx' :
110+ df = pd .read_excel (cluster_file , engine = 'openpyxl' )
111+ else :
112+ df = pd .read_csv (cluster_file , delimiter = '\t ' )
113+
114+ self ._data ['cluster_groups' ] = np .array (df [cluster_col_name ].values )
115+ self ._data ['cluster_ids' ] = np .array (df ['cluster_id' ].values )
104116
105117 def get_best_channel (self , unit ):
106118 template_idx = self .data ['spike_templates' ][
0 commit comments