@@ -123,51 +123,47 @@ def __init__(
123123 unit_colors = get_unit_colors (sorting_analyzer_or_templates )
124124
125125 channel_locations = sorting_analyzer_or_templates .get_channel_locations ()
126- extra_sparsity = False
126+ extra_sparsity = None
127127 # handle sparsity
128128 sparsity_mismatch_warning = (
129129 "The provided 'sparsity' includes additional channels not in the analyzer sparsity. "
130130 "These extra channels will be plotted as flat lines."
131131 )
132132 analyzer_sparsity = sorting_analyzer_or_templates .sparsity
133133 if channel_ids is not None :
134+ assert sparsity is None , "If 'channel_ids' is provided, 'sparsity' should be None!"
134135 channel_mask = np .tile (
135136 np .isin (sorting_analyzer_or_templates .channel_ids , channel_ids ),
136137 (len (sorting_analyzer_or_templates .unit_ids ), 1 ),
137138 )
138- sparsity = ChannelSparsity (
139+ extra_sparsity = ChannelSparsity (
139140 mask = channel_mask ,
140141 channel_ids = sorting_analyzer_or_templates .channel_ids ,
141142 unit_ids = sorting_analyzer_or_templates .unit_ids ,
142143 )
143- extra_sparsity = True
144- elif analyzer_sparsity is not None :
145- if sparsity is None :
146- sparsity = analyzer_sparsity
147- else :
148- extra_sparsity = True
149- else :
150- if sparsity is None :
151- unit_id_to_channel_ids = {
152- u : sorting_analyzer_or_templates .channel_ids for u in sorting_analyzer_or_templates .unit_ids
153- }
154- sparsity = ChannelSparsity .from_unit_id_to_channel_ids (
155- unit_id_to_channel_ids = unit_id_to_channel_ids ,
156- unit_ids = sorting_analyzer_or_templates .unit_ids ,
157- channel_ids = sorting_analyzer_or_templates .channel_ids ,
158- )
159- else :
160- assert isinstance (sparsity , ChannelSparsity ), "'sparsity' should be a ChannelSparsity object!"
144+ elif sparsity is not None :
145+ extra_sparsity = sparsity
161146
162147 if channel_ids is None :
163148 channel_ids = sorting_analyzer_or_templates .channel_ids
164149
165150 # assert provided sparsity is a subset of waveform sparsity
166- if extra_sparsity :
167- combined_mask = np .logical_or (analyzer_sparsity .mask , sparsity .mask )
168- if not np .all (np .sum (combined_mask , 1 ) - np .sum (sorting_analyzer_or_templates . sparsity .mask , 1 ) == 0 ):
151+ if extra_sparsity is not None and analyzer_sparsity is not None :
152+ combined_mask = np .logical_or (analyzer_sparsity .mask , extra_sparsity .mask )
153+ if not np .all (np .sum (combined_mask , 1 ) - np .sum (analyzer_sparsity .mask , 1 ) == 0 ):
169154 warn (sparsity_mismatch_warning )
170155
156+ final_sparsity = extra_sparsity if extra_sparsity is not None else analyzer_sparsity
157+ if final_sparsity is None :
158+ final_sparsity = ChannelSparsity (
159+ mask = np .ones (
160+ (len (sorting_analyzer_or_templates .unit_ids ), len (sorting_analyzer_or_templates .channel_ids )),
161+ dtype = bool ,
162+ ),
163+ unit_ids = sorting_analyzer_or_templates .unit_ids ,
164+ channel_ids = sorting_analyzer_or_templates .channel_ids ,
165+ )
166+
171167 # get templates
172168 if isinstance (sorting_analyzer_or_templates , Templates ):
173169 templates = sorting_analyzer_or_templates .templates_array
@@ -195,9 +191,7 @@ def __init__(
195191 wf_ext = sorting_analyzer_or_templates .get_extension ("waveforms" )
196192 if wf_ext is None :
197193 raise ValueError ("plot_waveforms() needs the extension 'waveforms'" )
198- wfs_by_ids = self ._get_wfs_by_ids (
199- sorting_analyzer_or_templates , unit_ids , sparsity , extra_sparsity = extra_sparsity
200- )
194+ wfs_by_ids = self ._get_wfs_by_ids (sorting_analyzer_or_templates , unit_ids , extra_sparsity = extra_sparsity )
201195 else :
202196 wfs_by_ids = None
203197
@@ -207,7 +201,8 @@ def __init__(
207201 nbefore = nbefore ,
208202 unit_ids = unit_ids ,
209203 channel_ids = channel_ids ,
210- sparsity = sparsity ,
204+ final_sparsity = final_sparsity ,
205+ extra_sparsity = extra_sparsity ,
211206 unit_colors = unit_colors ,
212207 channel_locations = channel_locations ,
213208 scale = scale ,
@@ -234,7 +229,6 @@ def __init__(
234229 alpha_templates = alpha_templates ,
235230 hide_unit_selector = hide_unit_selector ,
236231 plot_legend = plot_legend ,
237- extra_sparsity = extra_sparsity ,
238232 )
239233 BaseWidget .__init__ (self , plot_data , backend = backend , ** backend_kwargs )
240234
@@ -269,7 +263,7 @@ def plot_matplotlib(self, data_plot, **backend_kwargs):
269263 ax = self .axes .flatten ()[i ]
270264 color = dp .unit_colors [unit_id ]
271265
272- chan_inds = dp .sparsity .unit_id_to_channel_indices [unit_id ]
266+ chan_inds = dp .final_sparsity .unit_id_to_channel_indices [unit_id ]
273267 xvectors_flat = xvectors [:, chan_inds ].T .flatten ()
274268
275269 # plot waveforms
@@ -501,28 +495,27 @@ def plot_ipywidgets(self, data_plot, **backend_kwargs):
501495 if backend_kwargs ["display" ]:
502496 display (self .widget )
503497
504- def _get_wfs_by_ids (self , sorting_analyzer , unit_ids , sparsity , extra_sparsity = False ):
498+ def _get_wfs_by_ids (self , sorting_analyzer , unit_ids , extra_sparsity ):
505499 wfs_by_ids = {}
506500 wf_ext = sorting_analyzer .get_extension ("waveforms" )
507501 for unit_id in unit_ids :
508502 unit_index = list (sorting_analyzer .unit_ids ).index (unit_id )
509- if not extra_sparsity :
510- # get waveforms with default sparsity
511- if sorting_analyzer .is_sparse ():
512- wfs = wf_ext .get_waveforms_one_unit (unit_id , force_dense = False )
513- else :
514- wfs = wf_ext .get_waveforms_one_unit (unit_id )
515- wfs = wfs [:, :, sparsity .mask [unit_index ]]
503+ if extra_sparsity is None :
504+ wfs = wf_ext .get_waveforms_one_unit (unit_id , force_dense = False )
516505 else :
517506 # in this case we have to construct waveforms based on the extra sparsity and add the
518507 # sparse waveforms on the valid channels
508+ if sorting_analyzer .is_sparse ():
509+ original_mask = sorting_analyzer .sparsity .mask [unit_index ]
510+ else :
511+ original_mask = np .ones (len (sorting_analyzer .channel_ids ), dtype = bool )
519512 wfs_orig = wf_ext .get_waveforms_one_unit (unit_id , force_dense = False )
520513 wfs = np .zeros (
521- (wfs_orig .shape [0 ], wfs_orig .shape [1 ], sparsity .mask [unit_index ].sum ()), dtype = wfs_orig .dtype
514+ (wfs_orig .shape [0 ], wfs_orig .shape [1 ], extra_sparsity .mask [unit_index ].sum ()), dtype = wfs_orig .dtype
522515 )
523516 # fill in the existing waveforms channels
524- valid_wfs_indices = sparsity .mask [unit_index ][sorting_analyzer . sparsity . mask [ unit_index ] ]
525- valid_extra_indices = sorting_analyzer . sparsity . mask [ unit_index ][ sparsity .mask [unit_index ]]
517+ valid_wfs_indices = extra_sparsity .mask [unit_index ][original_mask ]
518+ valid_extra_indices = original_mask [ extra_sparsity .mask [unit_index ]]
526519 wfs [:, :, valid_extra_indices ] = wfs_orig [:, :, valid_wfs_indices ]
527520
528521 wfs_by_ids [unit_id ] = wfs
@@ -592,7 +585,7 @@ def _update_plot(self, change):
592585
593586 if data_plot ["plot_waveforms" ]:
594587 wfs_by_ids = self ._get_wfs_by_ids (
595- self .sorting_analyzer , unit_ids , data_plot [ "sparsity" ], extra_sparsity = data_plot ["extra_sparsity" ]
588+ self .sorting_analyzer , unit_ids , extra_sparsity = data_plot ["extra_sparsity" ]
596589 )
597590 data_plot ["wfs_by_ids" ] = wfs_by_ids
598591
@@ -638,7 +631,7 @@ def _plot_probe(self, ax, channel_locations, unit_ids):
638631
639632 # TODO this could be done with probeinterface plotting plotting tools!!
640633 for unit in unit_ids :
641- channel_inds = self .data_plot ["sparsity " ].unit_id_to_channel_indices [unit ]
634+ channel_inds = self .data_plot ["final_sparsity " ].unit_id_to_channel_indices [unit ]
642635 ax .plot (
643636 channel_locations [channel_inds , 0 ],
644637 channel_locations [channel_inds , 1 ],
0 commit comments