1212
1313class Kilosort :
1414
15- ks_files = [
15+ kilosort_files = [
1616 'params.py' ,
1717 'amplitudes.npy' ,
1818 'channel_map.npy' ,
@@ -36,18 +36,18 @@ class Kilosort:
3636 ]
3737
3838 # keys to self.files, .data are file name e.g. self.data['params'], etc.
39- ks_keys = [path .splitext (ks_file )[0 ] for ks_file in ks_files ]
39+ kilosort_keys = [path .splitext (kilosort_file )[0 ] for kilosort_file in kilosort_files ]
4040
41- def __init__ (self , ks_dir ):
42- self ._ks_dir = pathlib .Path (ks_dir )
41+ def __init__ (self , kilosort_dir ):
42+ self ._kilosort_dir = pathlib .Path (kilosort_dir )
4343 self ._files = {}
4444 self ._data = None
4545 self ._clusters = None
4646
47- params_filepath = ks_dir / 'params.py'
47+ params_filepath = kilosort_dir / 'params.py'
4848
4949 if not params_filepath .exists ():
50- raise FileNotFoundError (f'No Kilosort output found in: { ks_dir } ' )
50+ raise FileNotFoundError (f'No Kilosort output found in: { kilosort_dir } ' )
5151
5252 self ._info = {'time_created' : datetime .fromtimestamp (params_filepath .stat ().st_ctime ),
5353 'time_modified' : datetime .fromtimestamp (params_filepath .stat ().st_mtime )}
@@ -64,42 +64,44 @@ def info(self):
6464
6565 def _stat (self ):
6666 self ._data = {}
67- for ks_filename in Kilosort .ks_files :
68- ks_filepath = self ._ks_dir / ks_filename
67+ for kilosort_filename in Kilosort .kilosort_files :
68+ kilosort_filepath = self ._kilosort_dir / kilosort_filename
6969
70- if not ks_filepath .exists ():
71- log .debug ('skipping {} - does not exist' .format (ks_filepath ))
70+ if not kilosort_filepath .exists ():
71+ log .debug ('skipping {} - does not exist' .format (kilosort_filepath ))
7272 continue
7373
74- base , ext = path .splitext (ks_filename )
75- self ._files [base ] = ks_filepath
74+ base , ext = path .splitext (kilosort_filename )
75+ self ._files [base ] = kilosort_filepath
7676
77- if ks_filename == 'params.py' :
78- log .debug ('loading params.py {}' .format (ks_filepath ))
77+ if kilosort_filename == 'params.py' :
78+ log .debug ('loading params.py {}' .format (kilosort_filepath ))
7979 # params.py is a 'key = val' file
8080 params = {}
81- for line in open (ks_filepath , 'r' ).readlines ():
81+ for line in open (kilosort_filepath , 'r' ).readlines ():
8282 k , v = line .strip ('\n ' ).split ('=' )
8383 params [k .strip ()] = convert_to_number (v .strip ())
8484 log .debug ('params: {}' .format (params ))
8585 self ._data [base ] = params
8686
8787 if ext == '.npy' :
88- log .debug ('loading npy {}' .format (ks_filepath ))
89- d = np .load (ks_filepath , mmap_mode = 'r' , allow_pickle = False , fix_imports = False )
88+ log .debug ('loading npy {}' .format (kilosort_filepath ))
89+ d = np .load (kilosort_filepath , mmap_mode = 'r' ,
90+ allow_pickle = False , fix_imports = False )
9091 self ._data [base ] = (np .reshape (d , d .shape [0 ])
9192 if d .ndim == 2 and d .shape [1 ] == 1 else d )
9293
9394 # Read the Cluster Groups
9495 for cluster_pattern , cluster_col_name in zip (['cluster_groups.*' , 'cluster_KSLabel.*' ],
9596 ['group' , 'KSLabel' ]):
9697 try :
97- cluster_file = next (self ._ks_dir .glob (cluster_pattern ))
98- cluster_file_suffix = cluster_file .suffix
99- assert cluster_file_suffix in ('.csv' , '.tsv' , '.xlsx' )
100- break
98+ cluster_file = next (self ._kilosort_dir .glob (cluster_pattern ))
10199 except StopIteration :
102100 pass
101+
102+ cluster_file_suffix = cluster_file .suffix
103+ assert cluster_file_suffix in ('.csv' , '.tsv' , '.xlsx' )
104+ break
103105 else :
104106 raise FileNotFoundError (
105107 'Neither "cluster_groups" nor "cluster_KSLabel" file found!' )
@@ -118,7 +120,7 @@ def get_best_channel(self, unit):
118120 template_idx = self .data ['spike_templates' ][
119121 np .where (self .data ['spike_clusters' ] == unit )[0 ][0 ]]
120122 channel_templates = self .data ['templates' ][template_idx , :, :]
121- max_channel_idx = np .abs (np . abs ( channel_templates ).max (axis = 0 ) ).argmax ()
123+ max_channel_idx = np .abs (channel_templates ).max (axis = 0 ).argmax ()
122124 max_channel = self .data ['channel_map' ][max_channel_idx ]
123125
124126 return max_channel , max_channel_idx
@@ -174,12 +176,10 @@ def extract_clustering_info(cluster_output_dir):
174176
175177 # ---- Quality control? ----
176178 metric_filepath = cluster_output_dir / 'metrics.csv'
177- if metric_filepath .exists ():
178- is_qc = True
179+ is_qc = metric_filepath .exists ()
180+ if is_qc :
179181 if creation_time is None :
180182 creation_time = datetime .fromtimestamp (metric_filepath .stat ().st_ctime )
181- else :
182- is_qc = False
183183
184184 if creation_time is None :
185185 spiketimes_filepath = next (cluster_output_dir .glob ('spike_times.npy' ))
0 commit comments