@@ -558,15 +558,18 @@ def make(self, key):
558558
559559 unit_waveforms , unit_peak_waveforms = [], []
560560 if is_qc :
561- unit_wfs = np .load (ks_dir / 'mean_waveforms.npy' ) # unit x channel x sample
562- for unit_no , unit_wf in zip (ks .data ['cluster_ids' ], unit_wfs ):
561+ unit_waveforms = np .load (ks_dir / 'mean_waveforms.npy' ) # unit x channel x sample
562+ for unit_no , unit_waveform in zip (ks .data ['cluster_ids' ], unit_waveforms ):
563563 if unit_no in units :
564- for chn , chn_wf in zip (ks .data ['channel_map' ], unit_wf ):
565- unit_waveforms .append ({** units [unit_no ], ** channel2electrodes [chn ],
566- 'waveform_mean' : chn_wf })
567- if channel2electrodes [chn ]['electrode' ] == units [unit_no ]['electrode' ]:
568- unit_peak_waveforms .append ({** units [unit_no ],
569- 'peak_chn_waveform_mean' : chn_wf })
564+ for channel , channel_waveform in zip (ks .data ['channel_map' ],
565+ unit_waveform ):
566+ unit_waveforms .append ({
567+ ** units [unit_no ], ** channel2electrodes [channel ],
568+ 'waveform_mean' : channel_waveform })
569+ if channel2electrodes [channel ]['electrode' ] == units [unit_no ]['electrode' ]:
570+ unit_peak_waveforms .append ({
571+ ** units [unit_no ],
572+ 'peak_chn_waveform_mean' : channel_waveform })
570573 else :
571574 if acq_software == 'SpikeGLX' :
572575 npx_meta_fp = root_dir / (EphysRecording .EphysFile & key
@@ -578,16 +581,18 @@ def make(self, key):
578581 npx_recording = loaded_oe .probes [probe_sn ]
579582
580583 for unit_dict in units .values ():
581- spks = unit_dict ['spike_times' ]
582- wfs = npx_recording .extract_spike_waveforms (spks , ks .data ['channel_map' ]) # (sample x channel x spike)
583- wfs = wfs .transpose ((1 , 2 , 0 )) # (channel x spike x sample)
584- for chn , chn_wf in zip (ks .data ['channel_map' ], wfs ):
585- unit_waveforms .append ({** unit_dict , ** channel2electrodes [chn ],
586- 'waveform_mean' : chn_wf .mean (axis = 0 ),
587- 'waveforms' : chn_wf })
588- if channel2electrodes [chn ]['electrode' ] == unit_dict ['electrode' ]:
584+ spikes = unit_dict ['spike_times' ]
585+ waveforms = npx_recording .extract_spike_waveforms (
586+ spikes , ks .data ['channel_map' ]) # (sample x channel x spike)
587+ waveforms = waveforms .transpose ((1 , 2 , 0 )) # (channel x spike x sample)
588+ for channel , channel_waveform in zip (ks .data ['channel_map' ], waveforms ):
589+ unit_waveforms .append ({** unit_dict , ** channel2electrodes [channel ],
590+ 'waveform_mean' : channel_waveform .mean (axis = 0 ),
591+ 'waveforms' : channel_waveform })
592+ if channel2electrodes [channel ]['electrode' ] == unit_dict ['electrode' ]:
589593 unit_peak_waveforms .append ({
590- ** unit_dict , 'peak_chn_waveform_mean' : chn_wf .mean (axis = 0 )})
594+ ** unit_dict ,
595+ 'peak_chn_waveform_mean' : channel_waveform .mean (axis = 0 )})
591596
592597 self .insert (unit_peak_waveforms , ignore_extra_fields = True )
593598 self .Electrode .insert (unit_waveforms , ignore_extra_fields = True )
0 commit comments