Skip to content

Commit e431bd6

Browse files
authored
Merge pull request #3022 from JoeZiminski/add_plot_spikes_on_probe_widget
Add peaks_on_probe widget.
2 parents 0ec6825 + 7ab068b commit e431bd6

File tree

2 files changed

+221
-0
lines changed

2 files changed

+221
-0
lines changed
Lines changed: 218 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,218 @@
1+
from __future__ import annotations
2+
3+
import numpy as np
4+
5+
6+
from .base import BaseWidget, to_attr
7+
8+
9+
class PeaksOnProbeWidget(BaseWidget):
10+
"""
11+
Generate a plot of spike peaks showing their location on a plot
12+
of the probe. Color scaling represents spike amplitude.
13+
14+
The generated plot overlays the estimated position of a spike peak
15+
(as a single point for each peak) onto a plot of the probe. The
16+
dimensions of the plot are x axis: probe width, y axis: probe depth.
17+
18+
Plots of different sets of peaks can be created on subplots, by
19+
passing a list of peaks and corresponding peak locations.
20+
21+
Parameters
22+
----------
23+
recording : Recording
24+
A SpikeInterface recording object.
25+
peaks : np.array | list[np.ndarray]
26+
SpikeInterface 'peaks' array created with `detect_peaks()`,
27+
an array of length num_peaks with entries:
28+
(sample_index, channel_index, amplitude, segment_index)
29+
To plot different sets of peaks in subplots, pass a list of peaks, each
30+
with a corresponding entry in a list passed to `peak_locations`.
31+
peak_locations : np.array | list[np.ndarray]
32+
A SpikeInterface 'peak_locations' array created with `localize_peaks()`.
33+
an array of length num_peaks with entries: (x, y)
34+
To plot multiple peaks in subplots, pass a list of `peak_locations`
35+
here with each entry having a corresponding `peaks`.
36+
segment_index : None | int, default: None
37+
If set, only peaks from this recording segment will be used.
38+
time_range : None | Tuple, default: None
39+
The time period over which to include peaks. If `None`, peaks
40+
across the entire recording will be shown.
41+
ylim : None | Tuple, default: None
42+
The y-axis limits (i.e. the probe depth). If `None`, the entire
43+
probe will be displayed.
44+
decimate : int, default: 5
45+
For performance reasons, every nth peak is shown on the plot,
46+
where n is set by decimate. To plot all peaks, set `decimate=1`.
47+
"""
48+
49+
def __init__(
50+
self,
51+
recording,
52+
peaks,
53+
peak_locations,
54+
segment_index=None,
55+
time_range=None,
56+
ylim=None,
57+
decimate=5,
58+
backend=None,
59+
**backend_kwargs,
60+
):
61+
data_plot = dict(
62+
recording=recording,
63+
peaks=peaks,
64+
peak_locations=peak_locations,
65+
segment_index=segment_index,
66+
time_range=time_range,
67+
ylim=ylim,
68+
decimate=decimate,
69+
)
70+
71+
BaseWidget.__init__(self, data_plot, backend=backend, **backend_kwargs)
72+
73+
def plot_matplotlib(self, data_plot, **backend_kwargs):
74+
import matplotlib.pyplot as plt
75+
from .utils_matplotlib import make_mpl_figure
76+
from spikeinterface.widgets import plot_probe_map
77+
78+
dp = to_attr(data_plot)
79+
80+
peaks, peak_locations = self._check_and_format_inputs(
81+
dp.peaks,
82+
dp.peak_locations,
83+
)
84+
fs = dp.recording.get_sampling_frequency()
85+
num_plots = len(peaks)
86+
87+
# Set the maximum time to the end time of the longest segment
88+
if dp.time_range is None:
89+
90+
time_range = self._get_min_and_max_times_in_recording(dp.recording)
91+
else:
92+
time_range = dp.time_range
93+
94+
## Create the figure and axes
95+
if backend_kwargs["figsize"] is None:
96+
backend_kwargs.update(dict(figsize=(12, 8)))
97+
98+
self.figure, self.axes, self.ax = make_mpl_figure(num_axes=num_plots, **backend_kwargs)
99+
self.axes = self.axes[0]
100+
101+
# Plot each passed peaks / peak_locations over the probe on a separate subplot
102+
for ax_idx, (peaks_to_plot, peak_locs_to_plot) in enumerate(zip(peaks, peak_locations)):
103+
104+
ax = self.axes[ax_idx]
105+
plot_probe_map(dp.recording, ax=ax)
106+
107+
time_mask = self._get_peaks_time_mask(dp.recording, time_range, peaks_to_plot)
108+
109+
if dp.segment_index is not None:
110+
segment_mask = peaks_to_plot["segment_index"] == dp.segment_index
111+
mask = time_mask & segment_mask
112+
else:
113+
mask = time_mask
114+
115+
if not any(mask):
116+
raise ValueError(
117+
"No peaks within the time and segment mask found. Change `time_range` or `segment_index`"
118+
)
119+
120+
# only plot every nth peak
121+
peak_slice = slice(None, None, dp.decimate)
122+
123+
# Find the amplitudes for the colormap scaling
124+
# (intensity represents amplitude)
125+
amps = np.abs(peaks_to_plot["amplitude"][mask][peak_slice])
126+
amps /= np.quantile(amps, 0.95)
127+
cmap = plt.get_cmap("inferno")(amps)
128+
color_kwargs = dict(alpha=0.2, s=2, c=cmap)
129+
130+
# Plot the peaks over the plot, and set the y-axis limits.
131+
ax.scatter(
132+
peak_locs_to_plot["x"][mask][peak_slice], peak_locs_to_plot["y"][mask][peak_slice], **color_kwargs
133+
)
134+
135+
if dp.ylim is None:
136+
padding = 25 # arbitary padding just to give some space around highests and lowest peaks on the plot
137+
ylim = (np.min(peak_locs_to_plot["y"]) - padding, np.max(peak_locs_to_plot["y"]) + padding)
138+
else:
139+
ylim = dp.ylim
140+
141+
ax.set_ylim(ylim[0], ylim[1])
142+
143+
self.figure.suptitle(f"Peaks on Probe Plot")
144+
145+
def _get_peaks_time_mask(self, recording, time_range, peaks_to_plot):
146+
"""
147+
Return a mask of `True` where the peak is within the given time range
148+
and `False` otherwise.
149+
150+
This is a little complex, as each segment can have different start /
151+
end times. For each segment, find the time bounds relative to that
152+
segment time and fill the `time_mask` one segment at a time.
153+
"""
154+
time_mask = np.zeros(peaks_to_plot.size, dtype=bool)
155+
156+
for seg_idx in range(recording.get_num_segments()):
157+
158+
segment = recording.select_segments(seg_idx)
159+
160+
t_start_sample = segment.time_to_sample_index(time_range[0])
161+
t_stop_sample = segment.time_to_sample_index(time_range[1])
162+
163+
seg_mask = peaks_to_plot["segment_index"] == seg_idx
164+
165+
time_mask[seg_mask] = (t_start_sample < peaks_to_plot[seg_mask]["sample_index"]) & (
166+
peaks_to_plot[seg_mask]["sample_index"] < t_stop_sample
167+
)
168+
169+
return time_mask
170+
171+
def _get_min_and_max_times_in_recording(self, recording):
172+
"""
173+
Find the maximum and minimum time across all segments in the recording.
174+
For example if the segment times are (10-100 s, 0 - 50s) the
175+
min and max times are (0, 100)
176+
"""
177+
t_starts = []
178+
t_stops = []
179+
for seg_idx in range(recording.get_num_segments()):
180+
181+
segment = recording.select_segments(seg_idx)
182+
183+
t_starts.append(segment.sample_index_to_time(0))
184+
185+
t_stops.append(segment.sample_index_to_time(segment.get_num_samples() - 1))
186+
187+
time_range = (np.min(t_starts), np.max(t_stops))
188+
189+
return time_range
190+
191+
def _check_and_format_inputs(self, peaks, peak_locations):
192+
"""
193+
Check that the inpust are in expected form. Corresponding peaks
194+
and peak_locations of same size and format must be provided.
195+
"""
196+
types_are_list = [isinstance(peaks, list), isinstance(peak_locations, list)]
197+
198+
if not all(types_are_list):
199+
if any(types_are_list):
200+
raise ValueError("`peaks` and `peak_locations` must either be both lists or both not lists.")
201+
peaks = [peaks]
202+
peak_locations = [peak_locations]
203+
204+
if len(peaks) != len(peak_locations):
205+
raise ValueError(
206+
"If `peaks` and `peak_locations` are lists, they must contain "
207+
"the same number of (corresponding) peaks and peak locations."
208+
)
209+
210+
for idx, (peak, peak_loc) in enumerate(zip(peaks, peak_locations)):
211+
if peak.size != peak_loc.size:
212+
raise ValueError(
213+
f"The number of peaks and peak_locations do not "
214+
f"match for the {idx} input. For each spike peak, there "
215+
f"must be a corresponding peak location"
216+
)
217+
218+
return peaks, peak_locations

src/spikeinterface/widgets/widget_list.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
from .motion import MotionWidget, MotionInfoWidget
1414
from .multicomparison import MultiCompGraphWidget, MultiCompGlobalAgreementWidget, MultiCompAgreementBySorterWidget
1515
from .peak_activity import PeakActivityMapWidget
16+
from .peaks_on_probe import PeaksOnProbeWidget
1617
from .potential_merges import PotentialMergesWidget
1718
from .probe_map import ProbeMapWidget
1819
from .quality_metrics import QualityMetricsWidget
@@ -50,6 +51,7 @@
5051
MultiCompAgreementBySorterWidget,
5152
MultiCompGraphWidget,
5253
PeakActivityMapWidget,
54+
PeaksOnProbeWidget,
5355
PotentialMergesWidget,
5456
ProbeMapWidget,
5557
QualityMetricsWidget,
@@ -123,6 +125,7 @@
123125
plot_multicomparison_agreement_by_sorter = MultiCompAgreementBySorterWidget
124126
plot_multicomparison_graph = MultiCompGraphWidget
125127
plot_peak_activity = PeakActivityMapWidget
128+
plot_peaks_on_probe = PeaksOnProbeWidget
126129
plot_potential_merges = PotentialMergesWidget
127130
plot_probe_map = ProbeMapWidget
128131
plot_quality_metrics = QualityMetricsWidget

0 commit comments

Comments
 (0)