@@ -119,38 +119,50 @@ def __init__(
119119
120120 if unit_ids is None :
121121 unit_ids = sorting_analyzer_or_templates .unit_ids
122- if channel_ids is None :
123- channel_ids = sorting_analyzer_or_templates .channel_ids
124122 if unit_colors is None :
125123 unit_colors = get_unit_colors (sorting_analyzer_or_templates )
126124
127- channel_indices = [list (sorting_analyzer_or_templates .channel_ids ).index (ch ) for ch in channel_ids ]
128- channel_locations = sorting_analyzer_or_templates .get_channel_locations ()[channel_indices ]
129- extra_sparsity = False
130- if sorting_analyzer_or_templates .sparsity is not None :
131- if sparsity is None :
132- sparsity = sorting_analyzer_or_templates .sparsity
133- else :
134- # assert provided sparsity is a subset of waveform sparsity
135- combined_mask = np .logical_or (sorting_analyzer_or_templates .sparsity .mask , sparsity .mask )
136- assert np .all (np .sum (combined_mask , 1 ) - np .sum (sorting_analyzer_or_templates .sparsity .mask , 1 ) == 0 ), (
137- "The provided 'sparsity' needs to include only the sparse channels "
138- "used to extract waveforms (for example, by using a smaller 'radius_um')."
139- )
140- extra_sparsity = True
141- else :
142- if sparsity is None :
143- # in this case, we construct a dense sparsity
144- unit_id_to_channel_ids = {
145- u : sorting_analyzer_or_templates .channel_ids for u in sorting_analyzer_or_templates .unit_ids
146- }
147- sparsity = ChannelSparsity .from_unit_id_to_channel_ids (
148- unit_id_to_channel_ids = unit_id_to_channel_ids ,
149- unit_ids = sorting_analyzer_or_templates .unit_ids ,
150- channel_ids = sorting_analyzer_or_templates .channel_ids ,
151- )
152- else :
153- assert isinstance (sparsity , ChannelSparsity ), "'sparsity' should be a ChannelSparsity object!"
125+ channel_locations = sorting_analyzer_or_templates .get_channel_locations ()
126+ extra_sparsity = None
127+ # handle sparsity
128+ sparsity_mismatch_warning = (
129+ "The provided 'sparsity' includes additional channels not in the analyzer sparsity. "
130+ "These extra channels will be plotted as flat lines."
131+ )
132+ analyzer_sparsity = sorting_analyzer_or_templates .sparsity
133+ if channel_ids is not None :
134+ assert sparsity is None , "If 'channel_ids' is provided, 'sparsity' should be None!"
135+ channel_mask = np .tile (
136+ np .isin (sorting_analyzer_or_templates .channel_ids , channel_ids ),
137+ (len (sorting_analyzer_or_templates .unit_ids ), 1 ),
138+ )
139+ extra_sparsity = ChannelSparsity (
140+ mask = channel_mask ,
141+ channel_ids = sorting_analyzer_or_templates .channel_ids ,
142+ unit_ids = sorting_analyzer_or_templates .unit_ids ,
143+ )
144+ elif sparsity is not None :
145+ extra_sparsity = sparsity
146+
147+ if channel_ids is None :
148+ channel_ids = sorting_analyzer_or_templates .channel_ids
149+
150+ # assert provided sparsity is a subset of waveform sparsity
151+ if extra_sparsity is not None and analyzer_sparsity is not None :
152+ combined_mask = np .logical_or (analyzer_sparsity .mask , extra_sparsity .mask )
153+ if not np .all (np .sum (combined_mask , 1 ) - np .sum (analyzer_sparsity .mask , 1 ) == 0 ):
154+ warn (sparsity_mismatch_warning )
155+
156+ final_sparsity = extra_sparsity if extra_sparsity is not None else analyzer_sparsity
157+ if final_sparsity is None :
158+ final_sparsity = ChannelSparsity (
159+ mask = np .ones (
160+ (len (sorting_analyzer_or_templates .unit_ids ), len (sorting_analyzer_or_templates .channel_ids )),
161+ dtype = bool ,
162+ ),
163+ unit_ids = sorting_analyzer_or_templates .unit_ids ,
164+ channel_ids = sorting_analyzer_or_templates .channel_ids ,
165+ )
154166
155167 # get templates
156168 if isinstance (sorting_analyzer_or_templates , Templates ):
@@ -174,42 +186,23 @@ def __init__(
174186 templates_percentile_shading = None
175187 templates_shading = self ._get_template_shadings (unit_ids , templates_percentile_shading )
176188
177- wfs_by_ids = {}
178189 if plot_waveforms :
179190 # this must be a sorting_analyzer
180191 wf_ext = sorting_analyzer_or_templates .get_extension ("waveforms" )
181192 if wf_ext is None :
182193 raise ValueError ("plot_waveforms() needs the extension 'waveforms'" )
183- for unit_id in unit_ids :
184- unit_index = list (sorting_analyzer_or_templates .unit_ids ).index (unit_id )
185- if not extra_sparsity :
186- if sorting_analyzer_or_templates .is_sparse ():
187- # wfs = we.get_waveforms(unit_id)
188- wfs = wf_ext .get_waveforms_one_unit (unit_id , force_dense = False )
189- else :
190- # wfs = we.get_waveforms(unit_id, sparsity=sparsity)
191- wfs = wf_ext .get_waveforms_one_unit (unit_id )
192- wfs = wfs [:, :, sparsity .mask [unit_index ]]
193- else :
194- # in this case we have to slice the waveform sparsity based on the extra sparsity
195- # first get the sparse waveforms
196- # wfs = we.get_waveforms(unit_id)
197- wfs = wf_ext .get_waveforms_one_unit (unit_id , force_dense = False )
198- # find additional slice to apply to sparse waveforms
199- (wfs_sparse_indices ,) = np .nonzero (sorting_analyzer_or_templates .sparsity .mask [unit_index ])
200- (extra_sparse_indices ,) = np .nonzero (sparsity .mask [unit_index ])
201- (extra_slice ,) = np .nonzero (np .isin (wfs_sparse_indices , extra_sparse_indices ))
202- # apply extra sparsity
203- wfs = wfs [:, :, extra_slice ]
204- wfs_by_ids [unit_id ] = wfs
194+ wfs_by_ids = self ._get_wfs_by_ids (sorting_analyzer_or_templates , unit_ids , extra_sparsity = extra_sparsity )
195+ else :
196+ wfs_by_ids = None
205197
206198 plot_data = dict (
207199 sorting_analyzer_or_templates = sorting_analyzer_or_templates ,
208200 sampling_frequency = sorting_analyzer_or_templates .sampling_frequency ,
209201 nbefore = nbefore ,
210202 unit_ids = unit_ids ,
211203 channel_ids = channel_ids ,
212- sparsity = sparsity ,
204+ final_sparsity = final_sparsity ,
205+ extra_sparsity = extra_sparsity ,
213206 unit_colors = unit_colors ,
214207 channel_locations = channel_locations ,
215208 scale = scale ,
@@ -270,7 +263,7 @@ def plot_matplotlib(self, data_plot, **backend_kwargs):
270263 ax = self .axes .flatten ()[i ]
271264 color = dp .unit_colors [unit_id ]
272265
273- chan_inds = dp .sparsity .unit_id_to_channel_indices [unit_id ]
266+ chan_inds = dp .final_sparsity .unit_id_to_channel_indices [unit_id ]
274267 xvectors_flat = xvectors [:, chan_inds ].T .flatten ()
275268
276269 # plot waveforms
@@ -502,6 +495,32 @@ def plot_ipywidgets(self, data_plot, **backend_kwargs):
502495 if backend_kwargs ["display" ]:
503496 display (self .widget )
504497
498+ def _get_wfs_by_ids (self , sorting_analyzer , unit_ids , extra_sparsity ):
499+ wfs_by_ids = {}
500+ wf_ext = sorting_analyzer .get_extension ("waveforms" )
501+ for unit_id in unit_ids :
502+ unit_index = list (sorting_analyzer .unit_ids ).index (unit_id )
503+ if extra_sparsity is None :
504+ wfs = wf_ext .get_waveforms_one_unit (unit_id , force_dense = False )
505+ else :
506+ # in this case we have to construct waveforms based on the extra sparsity and add the
507+ # sparse waveforms on the valid channels
508+ if sorting_analyzer .is_sparse ():
509+ original_mask = sorting_analyzer .sparsity .mask [unit_index ]
510+ else :
511+ original_mask = np .ones (len (sorting_analyzer .channel_ids ), dtype = bool )
512+ wfs_orig = wf_ext .get_waveforms_one_unit (unit_id , force_dense = False )
513+ wfs = np .zeros (
514+ (wfs_orig .shape [0 ], wfs_orig .shape [1 ], extra_sparsity .mask [unit_index ].sum ()), dtype = wfs_orig .dtype
515+ )
516+ # fill in the existing waveforms channels
517+ valid_wfs_indices = extra_sparsity .mask [unit_index ][original_mask ]
518+ valid_extra_indices = original_mask [extra_sparsity .mask [unit_index ]]
519+ wfs [:, :, valid_extra_indices ] = wfs_orig [:, :, valid_wfs_indices ]
520+
521+ wfs_by_ids [unit_id ] = wfs
522+ return wfs_by_ids
523+
505524 def _get_template_shadings (self , unit_ids , templates_percentile_shading ):
506525 templates = self .templates_ext .get_templates (unit_ids = unit_ids , operator = "average" )
507526
@@ -538,6 +557,8 @@ def _update_plot(self, change):
538557 hide_axis = self .hide_axis_button .value
539558 do_shading = self .template_shading_button .value
540559
560+ data_plot = self .next_data_plot
561+
541562 if self .sorting_analyzer is not None :
542563 templates = self .templates_ext .get_templates (unit_ids = unit_ids , operator = "average" )
543564 templates_shadings = self ._get_template_shadings (unit_ids , data_plot ["templates_percentile_shading" ])
@@ -549,7 +570,6 @@ def _update_plot(self, change):
549570 channel_locations = self .templates .get_channel_locations ()
550571
551572 # matplotlib next_data_plot dict update at each call
552- data_plot = self .next_data_plot
553573 data_plot ["unit_ids" ] = unit_ids
554574 data_plot ["templates" ] = templates
555575 data_plot ["templates_shading" ] = templates_shadings
@@ -564,10 +584,10 @@ def _update_plot(self, change):
564584 data_plot ["scalebar" ] = self .scalebar .value
565585
566586 if data_plot ["plot_waveforms" ]:
567- wf_ext = self .sorting_analyzer . get_extension ( "waveforms" )
568- data_plot ["wfs_by_ids" ] = {
569- unit_id : wf_ext . get_waveforms_one_unit ( unit_id , force_dense = False ) for unit_id in unit_ids
570- }
587+ wfs_by_ids = self ._get_wfs_by_ids (
588+ self . sorting_analyzer , unit_ids , extra_sparsity = data_plot ["extra_sparsity" ]
589+ )
590+ data_plot [ "wfs_by_ids" ] = wfs_by_ids
571591
572592 # TODO option for plot_legend
573593 backend_kwargs = {}
@@ -611,7 +631,7 @@ def _plot_probe(self, ax, channel_locations, unit_ids):
611631
612632 # TODO this could be done with probeinterface plotting plotting tools!!
613633 for unit in unit_ids :
614- channel_inds = self .data_plot ["sparsity " ].unit_id_to_channel_indices [unit ]
634+ channel_inds = self .data_plot ["final_sparsity " ].unit_id_to_channel_indices [unit ]
615635 ax .plot (
616636 channel_locations [channel_inds , 0 ],
617637 channel_locations [channel_inds , 1 ],
0 commit comments