Skip to content

Commit ab59a93

Browse files
authored
Merge pull request #2850 from DradeAW/patch-1
Fix bug in plot templates
2 parents c9fc8e1 + 600f25e commit ab59a93

File tree

4 files changed

+108
-74
lines changed

4 files changed

+108
-74
lines changed

src/spikeinterface/widgets/tests/test_widgets.py

Lines changed: 25 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -190,8 +190,16 @@ def test_plot_unit_waveforms(self):
190190
backend=backend,
191191
**self.backend_kwargs[backend],
192192
)
193-
# test "larger" sparsity
194-
with self.assertRaises(AssertionError):
193+
# channel ids
194+
sw.plot_unit_waveforms(
195+
self.sorting_analyzer_sparse,
196+
channel_ids=self.sorting_analyzer_sparse.channel_ids[::3],
197+
unit_ids=unit_ids,
198+
backend=backend,
199+
**self.backend_kwargs[backend],
200+
)
201+
# test warning with "larger" sparsity
202+
with self.assertWarns(UserWarning):
195203
sw.plot_unit_waveforms(
196204
self.sorting_analyzer_sparse,
197205
sparsity=self.sparsity_large,
@@ -205,18 +213,18 @@ def test_plot_unit_templates(self):
205213
for backend in possible_backends:
206214
if backend not in self.skip_backends:
207215
print(f"Testing backend {backend}")
208-
print("Dense")
216+
# dense
209217
sw.plot_unit_templates(self.sorting_analyzer_dense, backend=backend, **self.backend_kwargs[backend])
210218
unit_ids = self.sorting.unit_ids[:6]
211-
print("Dense + radius")
219+
# dense + radius
212220
sw.plot_unit_templates(
213221
self.sorting_analyzer_dense,
214222
sparsity=self.sparsity_radius,
215223
unit_ids=unit_ids,
216224
backend=backend,
217225
**self.backend_kwargs[backend],
218226
)
219-
print("Dense + best")
227+
# dense + best
220228
sw.plot_unit_templates(
221229
self.sorting_analyzer_dense,
222230
sparsity=self.sparsity_best,
@@ -225,15 +233,13 @@ def test_plot_unit_templates(self):
225233
**self.backend_kwargs[backend],
226234
)
227235
# test different shadings
228-
print("Sparse")
229236
sw.plot_unit_templates(
230237
self.sorting_analyzer_sparse,
231238
unit_ids=unit_ids,
232239
templates_percentile_shading=None,
233240
backend=backend,
234241
**self.backend_kwargs[backend],
235242
)
236-
print("Sparse2")
237243
sw.plot_unit_templates(
238244
self.sorting_analyzer_sparse,
239245
unit_ids=unit_ids,
@@ -242,8 +248,6 @@ def test_plot_unit_templates(self):
242248
backend=backend,
243249
**self.backend_kwargs[backend],
244250
)
245-
# test different shadings
246-
print("Sparse3")
247251
sw.plot_unit_templates(
248252
self.sorting_analyzer_sparse,
249253
unit_ids=unit_ids,
@@ -252,15 +256,14 @@ def test_plot_unit_templates(self):
252256
shade_templates=False,
253257
**self.backend_kwargs[backend],
254258
)
255-
print("Sparse4")
256259
sw.plot_unit_templates(
257260
self.sorting_analyzer_sparse,
258261
unit_ids=unit_ids,
259262
templates_percentile_shading=0.1,
260263
backend=backend,
261264
**self.backend_kwargs[backend],
262265
)
263-
print("Extra sparsity")
266+
# extra sparsity
264267
sw.plot_unit_templates(
265268
self.sorting_analyzer_sparse,
266269
sparsity=self.sparsity_strict,
@@ -269,8 +272,18 @@ def test_plot_unit_templates(self):
269272
backend=backend,
270273
**self.backend_kwargs[backend],
271274
)
275+
# channel ids
276+
sw.plot_unit_templates(
277+
self.sorting_analyzer_sparse,
278+
channel_ids=self.sorting_analyzer_sparse.channel_ids[::3],
279+
unit_ids=unit_ids,
280+
templates_percentile_shading=[1, 10, 90, 99],
281+
backend=backend,
282+
**self.backend_kwargs[backend],
283+
)
284+
272285
# test "larger" sparsity
273-
with self.assertRaises(AssertionError):
286+
with self.assertWarns(UserWarning):
274287
sw.plot_unit_templates(
275288
self.sorting_analyzer_sparse,
276289
sparsity=self.sparsity_large,

src/spikeinterface/widgets/unit_templates.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,8 +24,8 @@ def plot_sortingview(self, data_plot, **backend_kwargs):
2424
assert len(dp.templates_shading) <= 4, "Only 2 ans 4 templates shading are supported in sortingview"
2525

2626
# ensure serializable for sortingview
27-
unit_id_to_channel_ids = dp.sparsity.unit_id_to_channel_ids
28-
unit_id_to_channel_indices = dp.sparsity.unit_id_to_channel_indices
27+
unit_id_to_channel_ids = dp.final_sparsity.unit_id_to_channel_ids
28+
unit_id_to_channel_indices = dp.final_sparsity.unit_id_to_channel_indices
2929

3030
unit_ids, channel_ids = make_serializable(dp.unit_ids, dp.channel_ids)
3131

src/spikeinterface/widgets/unit_waveforms.py

Lines changed: 80 additions & 60 deletions
Original file line numberDiff line numberDiff line change
@@ -119,38 +119,50 @@ def __init__(
119119

120120
if unit_ids is None:
121121
unit_ids = sorting_analyzer_or_templates.unit_ids
122-
if channel_ids is None:
123-
channel_ids = sorting_analyzer_or_templates.channel_ids
124122
if unit_colors is None:
125123
unit_colors = get_unit_colors(sorting_analyzer_or_templates)
126124

127-
channel_indices = [list(sorting_analyzer_or_templates.channel_ids).index(ch) for ch in channel_ids]
128-
channel_locations = sorting_analyzer_or_templates.get_channel_locations()[channel_indices]
129-
extra_sparsity = False
130-
if sorting_analyzer_or_templates.sparsity is not None:
131-
if sparsity is None:
132-
sparsity = sorting_analyzer_or_templates.sparsity
133-
else:
134-
# assert provided sparsity is a subset of waveform sparsity
135-
combined_mask = np.logical_or(sorting_analyzer_or_templates.sparsity.mask, sparsity.mask)
136-
assert np.all(np.sum(combined_mask, 1) - np.sum(sorting_analyzer_or_templates.sparsity.mask, 1) == 0), (
137-
"The provided 'sparsity' needs to include only the sparse channels "
138-
"used to extract waveforms (for example, by using a smaller 'radius_um')."
139-
)
140-
extra_sparsity = True
141-
else:
142-
if sparsity is None:
143-
# in this case, we construct a dense sparsity
144-
unit_id_to_channel_ids = {
145-
u: sorting_analyzer_or_templates.channel_ids for u in sorting_analyzer_or_templates.unit_ids
146-
}
147-
sparsity = ChannelSparsity.from_unit_id_to_channel_ids(
148-
unit_id_to_channel_ids=unit_id_to_channel_ids,
149-
unit_ids=sorting_analyzer_or_templates.unit_ids,
150-
channel_ids=sorting_analyzer_or_templates.channel_ids,
151-
)
152-
else:
153-
assert isinstance(sparsity, ChannelSparsity), "'sparsity' should be a ChannelSparsity object!"
125+
channel_locations = sorting_analyzer_or_templates.get_channel_locations()
126+
extra_sparsity = None
127+
# handle sparsity
128+
sparsity_mismatch_warning = (
129+
"The provided 'sparsity' includes additional channels not in the analyzer sparsity. "
130+
"These extra channels will be plotted as flat lines."
131+
)
132+
analyzer_sparsity = sorting_analyzer_or_templates.sparsity
133+
if channel_ids is not None:
134+
assert sparsity is None, "If 'channel_ids' is provided, 'sparsity' should be None!"
135+
channel_mask = np.tile(
136+
np.isin(sorting_analyzer_or_templates.channel_ids, channel_ids),
137+
(len(sorting_analyzer_or_templates.unit_ids), 1),
138+
)
139+
extra_sparsity = ChannelSparsity(
140+
mask=channel_mask,
141+
channel_ids=sorting_analyzer_or_templates.channel_ids,
142+
unit_ids=sorting_analyzer_or_templates.unit_ids,
143+
)
144+
elif sparsity is not None:
145+
extra_sparsity = sparsity
146+
147+
if channel_ids is None:
148+
channel_ids = sorting_analyzer_or_templates.channel_ids
149+
150+
# assert provided sparsity is a subset of waveform sparsity
151+
if extra_sparsity is not None and analyzer_sparsity is not None:
152+
combined_mask = np.logical_or(analyzer_sparsity.mask, extra_sparsity.mask)
153+
if not np.all(np.sum(combined_mask, 1) - np.sum(analyzer_sparsity.mask, 1) == 0):
154+
warn(sparsity_mismatch_warning)
155+
156+
final_sparsity = extra_sparsity if extra_sparsity is not None else analyzer_sparsity
157+
if final_sparsity is None:
158+
final_sparsity = ChannelSparsity(
159+
mask=np.ones(
160+
(len(sorting_analyzer_or_templates.unit_ids), len(sorting_analyzer_or_templates.channel_ids)),
161+
dtype=bool,
162+
),
163+
unit_ids=sorting_analyzer_or_templates.unit_ids,
164+
channel_ids=sorting_analyzer_or_templates.channel_ids,
165+
)
154166

155167
# get templates
156168
if isinstance(sorting_analyzer_or_templates, Templates):
@@ -174,42 +186,23 @@ def __init__(
174186
templates_percentile_shading = None
175187
templates_shading = self._get_template_shadings(unit_ids, templates_percentile_shading)
176188

177-
wfs_by_ids = {}
178189
if plot_waveforms:
179190
# this must be a sorting_analyzer
180191
wf_ext = sorting_analyzer_or_templates.get_extension("waveforms")
181192
if wf_ext is None:
182193
raise ValueError("plot_waveforms() needs the extension 'waveforms'")
183-
for unit_id in unit_ids:
184-
unit_index = list(sorting_analyzer_or_templates.unit_ids).index(unit_id)
185-
if not extra_sparsity:
186-
if sorting_analyzer_or_templates.is_sparse():
187-
# wfs = we.get_waveforms(unit_id)
188-
wfs = wf_ext.get_waveforms_one_unit(unit_id, force_dense=False)
189-
else:
190-
# wfs = we.get_waveforms(unit_id, sparsity=sparsity)
191-
wfs = wf_ext.get_waveforms_one_unit(unit_id)
192-
wfs = wfs[:, :, sparsity.mask[unit_index]]
193-
else:
194-
# in this case we have to slice the waveform sparsity based on the extra sparsity
195-
# first get the sparse waveforms
196-
# wfs = we.get_waveforms(unit_id)
197-
wfs = wf_ext.get_waveforms_one_unit(unit_id, force_dense=False)
198-
# find additional slice to apply to sparse waveforms
199-
(wfs_sparse_indices,) = np.nonzero(sorting_analyzer_or_templates.sparsity.mask[unit_index])
200-
(extra_sparse_indices,) = np.nonzero(sparsity.mask[unit_index])
201-
(extra_slice,) = np.nonzero(np.isin(wfs_sparse_indices, extra_sparse_indices))
202-
# apply extra sparsity
203-
wfs = wfs[:, :, extra_slice]
204-
wfs_by_ids[unit_id] = wfs
194+
wfs_by_ids = self._get_wfs_by_ids(sorting_analyzer_or_templates, unit_ids, extra_sparsity=extra_sparsity)
195+
else:
196+
wfs_by_ids = None
205197

206198
plot_data = dict(
207199
sorting_analyzer_or_templates=sorting_analyzer_or_templates,
208200
sampling_frequency=sorting_analyzer_or_templates.sampling_frequency,
209201
nbefore=nbefore,
210202
unit_ids=unit_ids,
211203
channel_ids=channel_ids,
212-
sparsity=sparsity,
204+
final_sparsity=final_sparsity,
205+
extra_sparsity=extra_sparsity,
213206
unit_colors=unit_colors,
214207
channel_locations=channel_locations,
215208
scale=scale,
@@ -270,7 +263,7 @@ def plot_matplotlib(self, data_plot, **backend_kwargs):
270263
ax = self.axes.flatten()[i]
271264
color = dp.unit_colors[unit_id]
272265

273-
chan_inds = dp.sparsity.unit_id_to_channel_indices[unit_id]
266+
chan_inds = dp.final_sparsity.unit_id_to_channel_indices[unit_id]
274267
xvectors_flat = xvectors[:, chan_inds].T.flatten()
275268

276269
# plot waveforms
@@ -502,6 +495,32 @@ def plot_ipywidgets(self, data_plot, **backend_kwargs):
502495
if backend_kwargs["display"]:
503496
display(self.widget)
504497

498+
def _get_wfs_by_ids(self, sorting_analyzer, unit_ids, extra_sparsity):
499+
wfs_by_ids = {}
500+
wf_ext = sorting_analyzer.get_extension("waveforms")
501+
for unit_id in unit_ids:
502+
unit_index = list(sorting_analyzer.unit_ids).index(unit_id)
503+
if extra_sparsity is None:
504+
wfs = wf_ext.get_waveforms_one_unit(unit_id, force_dense=False)
505+
else:
506+
# in this case we have to construct waveforms based on the extra sparsity and add the
507+
# sparse waveforms on the valid channels
508+
if sorting_analyzer.is_sparse():
509+
original_mask = sorting_analyzer.sparsity.mask[unit_index]
510+
else:
511+
original_mask = np.ones(len(sorting_analyzer.channel_ids), dtype=bool)
512+
wfs_orig = wf_ext.get_waveforms_one_unit(unit_id, force_dense=False)
513+
wfs = np.zeros(
514+
(wfs_orig.shape[0], wfs_orig.shape[1], extra_sparsity.mask[unit_index].sum()), dtype=wfs_orig.dtype
515+
)
516+
# fill in the existing waveforms channels
517+
valid_wfs_indices = extra_sparsity.mask[unit_index][original_mask]
518+
valid_extra_indices = original_mask[extra_sparsity.mask[unit_index]]
519+
wfs[:, :, valid_extra_indices] = wfs_orig[:, :, valid_wfs_indices]
520+
521+
wfs_by_ids[unit_id] = wfs
522+
return wfs_by_ids
523+
505524
def _get_template_shadings(self, unit_ids, templates_percentile_shading):
506525
templates = self.templates_ext.get_templates(unit_ids=unit_ids, operator="average")
507526

@@ -538,6 +557,8 @@ def _update_plot(self, change):
538557
hide_axis = self.hide_axis_button.value
539558
do_shading = self.template_shading_button.value
540559

560+
data_plot = self.next_data_plot
561+
541562
if self.sorting_analyzer is not None:
542563
templates = self.templates_ext.get_templates(unit_ids=unit_ids, operator="average")
543564
templates_shadings = self._get_template_shadings(unit_ids, data_plot["templates_percentile_shading"])
@@ -549,7 +570,6 @@ def _update_plot(self, change):
549570
channel_locations = self.templates.get_channel_locations()
550571

551572
# matplotlib next_data_plot dict update at each call
552-
data_plot = self.next_data_plot
553573
data_plot["unit_ids"] = unit_ids
554574
data_plot["templates"] = templates
555575
data_plot["templates_shading"] = templates_shadings
@@ -564,10 +584,10 @@ def _update_plot(self, change):
564584
data_plot["scalebar"] = self.scalebar.value
565585

566586
if data_plot["plot_waveforms"]:
567-
wf_ext = self.sorting_analyzer.get_extension("waveforms")
568-
data_plot["wfs_by_ids"] = {
569-
unit_id: wf_ext.get_waveforms_one_unit(unit_id, force_dense=False) for unit_id in unit_ids
570-
}
587+
wfs_by_ids = self._get_wfs_by_ids(
588+
self.sorting_analyzer, unit_ids, extra_sparsity=data_plot["extra_sparsity"]
589+
)
590+
data_plot["wfs_by_ids"] = wfs_by_ids
571591

572592
# TODO option for plot_legend
573593
backend_kwargs = {}
@@ -611,7 +631,7 @@ def _plot_probe(self, ax, channel_locations, unit_ids):
611631

612632
# TODO this could be done with probeinterface plotting plotting tools!!
613633
for unit in unit_ids:
614-
channel_inds = self.data_plot["sparsity"].unit_id_to_channel_indices[unit]
634+
channel_inds = self.data_plot["final_sparsity"].unit_id_to_channel_indices[unit]
615635
ax.plot(
616636
channel_locations[channel_inds, 0],
617637
channel_locations[channel_inds, 1],

src/spikeinterface/widgets/utils.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -151,6 +151,7 @@ def array_to_image(
151151
output_image : 3D numpy array
152152
153153
"""
154+
import matplotlib.pyplot as plt
154155

155156
from scipy.ndimage import zoom
156157

0 commit comments

Comments
 (0)