 class Kilosort:

     _kilosort_core_files = [
-        'params.py',
-        'amplitudes.npy',
-        'channel_map.npy',
-        'channel_positions.npy',
-        'pc_features.npy',
-        'pc_feature_ind.npy',
-        'similar_templates.npy',
-        'spike_templates.npy',
-        'spike_times.npy',
-        'template_features.npy',
-        'template_feature_ind.npy',
-        'templates.npy',
-        'templates_ind.npy',
-        'whitening_mat.npy',
-        'whitening_mat_inv.npy',
-        'spike_clusters.npy'
+        "params.py",
+        "amplitudes.npy",
+        "channel_map.npy",
+        "channel_positions.npy",
+        "pc_features.npy",
+        "pc_feature_ind.npy",
+        "similar_templates.npy",
+        "spike_templates.npy",
+        "spike_times.npy",
+        "template_features.npy",
+        "template_feature_ind.npy",
+        "templates.npy",
+        "templates_ind.npy",
+        "whitening_mat.npy",
+        "whitening_mat_inv.npy",
+        "spike_clusters.npy",
     ]

     _kilosort_additional_files = [
-        'spike_times_sec.npy',
-        'spike_times_sec_adj.npy',
-        'cluster_groups.csv',
-        'cluster_KSLabel.tsv'
+        "spike_times_sec.npy",
+        "spike_times_sec_adj.npy",
+        "cluster_groups.csv",
+        "cluster_KSLabel.tsv",
     ]

     kilosort_files = _kilosort_core_files + _kilosort_additional_files
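For orientation: `_kilosort_core_files` lists the outputs a Kilosort run is expected to always produce, while `_kilosort_additional_files` covers optional post-processing and curation artifacts; the split presumably lets validation insist on the former while `_load()` simply skips whatever is absent. A minimal sketch, not part of the commit, of auditing an output folder against the two lists; the `ks_dir` path is hypothetical:

# Illustrative only: check a Kilosort output directory against the two lists.
from pathlib import Path

ks_dir = Path("/data/session1/kilosort_output")  # hypothetical location
missing_core = [f for f in Kilosort._kilosort_core_files if not (ks_dir / f).exists()]
optional_found = [f for f in Kilosort._kilosort_additional_files if (ks_dir / f).exists()]
print(f"missing core files: {missing_core}")
print(f"optional files present: {optional_found}")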
@@ -48,9 +48,11 @@ def __init__(self, kilosort_dir):

         self.validate()

-        params_filepath = kilosort_dir / 'params.py'
-        self._info = {'time_created': datetime.fromtimestamp(params_filepath.stat().st_ctime),
-                      'time_modified': datetime.fromtimestamp(params_filepath.stat().st_mtime)}
+        params_filepath = kilosort_dir / "params.py"
+        self._info = {
+            "time_created": datetime.fromtimestamp(params_filepath.stat().st_ctime),
+            "time_modified": datetime.fromtimestamp(params_filepath.stat().st_mtime),
+        }

     @property
     def data(self):
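A side note on the timestamp block reformatted above: `st_ctime` is a true creation time only on Windows; on Linux it is the inode change time, so `time_created` is best-effort. A standalone sketch of the same logic, with a hypothetical path:

# Standalone sketch of the _info construction above; the path is hypothetical.
from datetime import datetime
from pathlib import Path

params_filepath = Path("/data/session1/kilosort_output/params.py")
stat = params_filepath.stat()
info = {
    "time_created": datetime.fromtimestamp(stat.st_ctime),   # inode change time on Linux
    "time_modified": datetime.fromtimestamp(stat.st_mtime),  # last content modification
}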
@@ -72,136 +74,157 @@ def validate(self):
             if not full_path.exists():
                 missing_files.append(f)
         if missing_files:
-            raise FileNotFoundError(f'Kilosort files missing in ({self._kilosort_dir}):'
-                                    f' {missing_files}')
+            raise FileNotFoundError(
+                f"Kilosort files missing in ({self._kilosort_dir}):" f" {missing_files}"
+            )

     def _load(self):
         self._data = {}
         for kilosort_filename in Kilosort.kilosort_files:
             kilosort_filepath = self._kilosort_dir / kilosort_filename

             if not kilosort_filepath.exists():
-                log.debug('skipping {} - does not exist'.format(kilosort_filepath))
+                log.debug("skipping {} - does not exist".format(kilosort_filepath))
                 continue

             base, ext = path.splitext(kilosort_filename)
             self._files[base] = kilosort_filepath

-            if kilosort_filename == 'params.py':
-                log.debug('loading params.py {}'.format(kilosort_filepath))
+            if kilosort_filename == "params.py":
+                log.debug("loading params.py {}".format(kilosort_filepath))
                 # params.py is a 'key = val' file
                 params = {}
-                for line in open(kilosort_filepath, 'r').readlines():
-                    k, v = line.strip('\n').split('=')
+                for line in open(kilosort_filepath, "r").readlines():
+                    k, v = line.strip("\n").split("=")
                     params[k.strip()] = convert_to_number(v.strip())
-                log.debug('params: {}'.format(params))
+                log.debug("params: {}".format(params))
                 self._data[base] = params

-            if ext == '.npy':
-                log.debug('loading npy {}'.format(kilosort_filepath))
-                d = np.load(kilosort_filepath, mmap_mode='r',
-                            allow_pickle=False, fix_imports=False)
-                self._data[base] = (np.reshape(d, d.shape[0])
-                                    if d.ndim == 2 and d.shape[1] == 1 else d)
+            if ext == ".npy":
+                log.debug("loading npy {}".format(kilosort_filepath))
+                d = np.load(
+                    kilosort_filepath,
+                    mmap_mode="r",
+                    allow_pickle=False,
+                    fix_imports=False,
+                )
+                self._data[base] = (
+                    np.reshape(d, d.shape[0]) if d.ndim == 2 and d.shape[1] == 1 else d
+                )

-        self._data['channel_map'] = self._data['channel_map'].flatten()
+        self._data["channel_map"] = self._data["channel_map"].flatten()

         # Read the Cluster Groups
-        for cluster_pattern, cluster_col_name in zip(['cluster_group.*', 'cluster_KSLabel.*'],
-                                                     ['group', 'KSLabel']):
+        for cluster_pattern, cluster_col_name in zip(
+            ["cluster_group.*", "cluster_KSLabel.*"], ["group", "KSLabel"]
+        ):
             try:
                 cluster_file = next(self._kilosort_dir.glob(cluster_pattern))
             except StopIteration:
                 pass
             else:
                 cluster_file_suffix = cluster_file.suffix
-                assert cluster_file_suffix in ('.tsv', '.xlsx')
+                assert cluster_file_suffix in (".tsv", ".xlsx")
                 break
         else:
             raise FileNotFoundError(
-                'Neither "cluster_groups" nor "cluster_KSLabel" file found!')
+                'Neither "cluster_groups" nor "cluster_KSLabel" file found!'
+            )

-        if cluster_file_suffix == '.tsv':
-            df = pd.read_csv(cluster_file, sep='\t', header=0)
-        elif cluster_file_suffix == '.xlsx':
-            df = pd.read_excel(cluster_file, engine='openpyxl')
+        if cluster_file_suffix == ".tsv":
+            df = pd.read_csv(cluster_file, sep="\t", header=0)
+        elif cluster_file_suffix == ".xlsx":
+            df = pd.read_excel(cluster_file, engine="openpyxl")
         else:
-            df = pd.read_csv(cluster_file, delimiter='\t')
+            df = pd.read_csv(cluster_file, delimiter="\t")

-        self._data['cluster_groups'] = np.array(df[cluster_col_name].values)
-        self._data['cluster_ids'] = np.array(df['cluster_id'].values)
+        self._data["cluster_groups"] = np.array(df[cluster_col_name].values)
+        self._data["cluster_ids"] = np.array(df["cluster_id"].values)

     def get_best_channel(self, unit):
-        template_idx = self.data['spike_templates'][
-            np.where(self.data['spike_clusters'] == unit)[0][0]]
-        channel_templates = self.data['templates'][template_idx, :, :]
+        template_idx = self.data["spike_templates"][
+            np.where(self.data["spike_clusters"] == unit)[0][0]
+        ]
+        channel_templates = self.data["templates"][template_idx, :, :]
         max_channel_idx = np.abs(channel_templates).max(axis=0).argmax()
-        max_channel = self.data['channel_map'][max_channel_idx]
+        max_channel = self.data["channel_map"][max_channel_idx]

         return max_channel, max_channel_idx

     def extract_spike_depths(self):
-        """ Reimplemented from https://github.com/cortex-lab/spikes/blob/master/analysis/ksDriftmap.m """
-
-        if 'pc_features' in self.data:
-            ycoords = self.data['channel_positions'][:, 1]
-            pc_features = self.data['pc_features'][:, 0, :]  # 1st PC only
+        """Reimplemented from https://github.com/cortex-lab/spikes/blob/master/analysis/ksDriftmap.m"""
+
+        if "pc_features" in self.data:
+            ycoords = self.data["channel_positions"][:, 1]
+            pc_features = self.data["pc_features"][:, 0, :]  # 1st PC only
             pc_features = np.where(pc_features < 0, 0, pc_features)

             # ---- compute center of mass of these features (spike depths) ----

             # which channels for each spike?
-            spk_feature_ind = self.data['pc_feature_ind'][self.data['spike_templates'], :]
+            spk_feature_ind = self.data["pc_feature_ind"][
+                self.data["spike_templates"], :
+            ]
             # ycoords of those channels?
             spk_feature_ycoord = ycoords[spk_feature_ind]
             # center of mass is sum(coords.*features)/sum(features)
-            self._data['spike_depths'] = (np.sum(spk_feature_ycoord * pc_features**2, axis=1)
-                                          / np.sum(pc_features**2, axis=1))
+            self._data["spike_depths"] = np.sum(
+                spk_feature_ycoord * pc_features**2, axis=1
+            ) / np.sum(pc_features**2, axis=1)
         else:
-            self._data['spike_depths'] = None
+            self._data["spike_depths"] = None

         # ---- extract spike sites ----
-        max_site_ind = np.argmax(np.abs(self.data['templates']).max(axis=1), axis=1)
-        spike_site_ind = max_site_ind[self.data['spike_templates']]
-        self._data['spike_sites'] = self.data['channel_map'][spike_site_ind]
+        max_site_ind = np.argmax(np.abs(self.data["templates"]).max(axis=1), axis=1)
+        spike_site_ind = max_site_ind[self.data["spike_templates"]]
+        self._data["spike_sites"] = self.data["channel_map"][spike_site_ind]


 def extract_clustering_info(cluster_output_dir):
     creation_time = None

-    phy_curation_indicators = ['Merge clusters', 'Split cluster', 'Change metadata_group']
+    phy_curation_indicators = [
+        "Merge clusters",
+        "Split cluster",
+        "Change metadata_group",
+    ]
     # ---- Manual curation? ----
-    phylog_filepath = cluster_output_dir / 'phy.log'
+    phylog_filepath = cluster_output_dir / "phy.log"
     if phylog_filepath.exists():
         phylog = pd.read_fwf(phylog_filepath, colspecs=[(6, 40), (41, 250)])
-        phylog.columns = ['meta', 'detail']
-        curation_row = [bool(re.match('|'.join(phy_curation_indicators), str(s)))
-                        for s in phylog.detail]
+        phylog.columns = ["meta", "detail"]
+        curation_row = [
+            bool(re.match("|".join(phy_curation_indicators), str(s)))
+            for s in phylog.detail
+        ]
         is_curated = bool(np.any(curation_row))
         if creation_time is None and is_curated:
             row_meta = phylog.meta[np.where(curation_row)[0].max()]
-            datetime_str = re.search('\d{2}-\d{2}-\d{2}\s\d{2}:\d{2}:\d{2}', row_meta)
+            datetime_str = re.search("\d{2}-\d{2}-\d{2}\s\d{2}:\d{2}:\d{2}", row_meta)
             if datetime_str:
-                creation_time = datetime.strptime(datetime_str.group(), '%Y-%m-%d %H:%M:%S')
+                creation_time = datetime.strptime(
+                    datetime_str.group(), "%Y-%m-%d %H:%M:%S"
+                )
             else:
                 creation_time = datetime.fromtimestamp(phylog_filepath.stat().st_ctime)
-                time_str = re.search('\d{2}:\d{2}:\d{2}', row_meta)
+                time_str = re.search("\d{2}:\d{2}:\d{2}", row_meta)
                 if time_str:
                     creation_time = datetime.combine(
                         creation_time.date(),
-                        datetime.strptime(time_str.group(), '%H:%M:%S').time())
+                        datetime.strptime(time_str.group(), "%H:%M:%S").time(),
+                    )
     else:
         is_curated = False

     # ---- Quality control? ----
-    metric_filepath = cluster_output_dir / 'metrics.csv'
+    metric_filepath = cluster_output_dir / "metrics.csv"
     is_qc = metric_filepath.exists()
     if is_qc:
         if creation_time is None:
             creation_time = datetime.fromtimestamp(metric_filepath.stat().st_ctime)

     if creation_time is None:
-        spiketimes_filepath = next(cluster_output_dir.glob('spike_times.npy'))
+        spiketimes_filepath = next(cluster_output_dir.glob("spike_times.npy"))
         creation_time = datetime.fromtimestamp(spiketimes_filepath.stat().st_ctime)

     return creation_time, is_curated, is_qc
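For readers new to this module: `extract_spike_depths` estimates each spike's depth as a center of mass along the probe, depth = sum_c(y_c * f_c^2) / sum_c(f_c^2), where y_c is the y-coordinate of channel c and f_c the rectified first-PC feature on that channel; squaring the features follows the referenced ksDriftmap.m. A hypothetical end-to-end use of the two public entry points; the directory and unit id are placeholders:

# Hypothetical usage; paths and the unit id are placeholders.
from pathlib import Path

ks_dir = Path("/data/session1/kilosort_output")
ks = Kilosort(ks_dir)                       # validates files, records params.py timestamps
ks.extract_spike_depths()                   # adds spike_depths / spike_sites to ks.data
best_channel, best_channel_idx = ks.get_best_channel(unit=0)

creation_time, is_curated, is_qc = extract_clustering_info(ks_dir)
print(creation_time, is_curated, is_qc)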