|
| 1 | +from __future__ import annotations |
| 2 | + |
| 3 | +import numpy as np |
| 4 | + |
| 5 | + |
| 6 | +from .base import BaseWidget, to_attr |
| 7 | + |
| 8 | + |
| 9 | +class PeaksOnProbeWidget(BaseWidget): |
| 10 | + """ |
| 11 | + Generate a plot of spike peaks showing their location on a plot |
| 12 | + of the probe. Color scaling represents spike amplitude. |
| 13 | +
|
| 14 | + The generated plot overlays the estimated position of a spike peak |
| 15 | + (as a single point for each peak) onto a plot of the probe. The |
| 16 | + dimensions of the plot are x axis: probe width, y axis: probe depth. |
| 17 | +
|
| 18 | + Plots of different sets of peaks can be created on subplots, by |
| 19 | + passing a list of peaks and corresponding peak locations. |
| 20 | +
|
| 21 | + Parameters |
| 22 | + ---------- |
| 23 | + recording : Recording |
| 24 | + A SpikeInterface recording object. |
| 25 | + peaks : np.array | list[np.ndarray] |
| 26 | + SpikeInterface 'peaks' array created with `detect_peaks()`, |
| 27 | + an array of length num_peaks with entries: |
| 28 | + (sample_index, channel_index, amplitude, segment_index) |
| 29 | + To plot different sets of peaks in subplots, pass a list of peaks, each |
| 30 | + with a corresponding entry in a list passed to `peak_locations`. |
| 31 | + peak_locations : np.array | list[np.ndarray] |
| 32 | + A SpikeInterface 'peak_locations' array created with `localize_peaks()`. |
| 33 | + an array of length num_peaks with entries: (x, y) |
| 34 | + To plot multiple peaks in subplots, pass a list of `peak_locations` |
| 35 | + here with each entry having a corresponding `peaks`. |
| 36 | + segment_index : None | int, default: None |
| 37 | + If set, only peaks from this recording segment will be used. |
| 38 | + time_range : None | Tuple, default: None |
| 39 | + The time period over which to include peaks. If `None`, peaks |
| 40 | + across the entire recording will be shown. |
| 41 | + ylim : None | Tuple, default: None |
| 42 | + The y-axis limits (i.e. the probe depth). If `None`, the entire |
| 43 | + probe will be displayed. |
| 44 | + decimate : int, default: 5 |
| 45 | + For performance reasons, every nth peak is shown on the plot, |
| 46 | + where n is set by decimate. To plot all peaks, set `decimate=1`. |
| 47 | + """ |
| 48 | + |
| 49 | + def __init__( |
| 50 | + self, |
| 51 | + recording, |
| 52 | + peaks, |
| 53 | + peak_locations, |
| 54 | + segment_index=None, |
| 55 | + time_range=None, |
| 56 | + ylim=None, |
| 57 | + decimate=5, |
| 58 | + backend=None, |
| 59 | + **backend_kwargs, |
| 60 | + ): |
| 61 | + data_plot = dict( |
| 62 | + recording=recording, |
| 63 | + peaks=peaks, |
| 64 | + peak_locations=peak_locations, |
| 65 | + segment_index=segment_index, |
| 66 | + time_range=time_range, |
| 67 | + ylim=ylim, |
| 68 | + decimate=decimate, |
| 69 | + ) |
| 70 | + |
| 71 | + BaseWidget.__init__(self, data_plot, backend=backend, **backend_kwargs) |
| 72 | + |
| 73 | + def plot_matplotlib(self, data_plot, **backend_kwargs): |
| 74 | + import matplotlib.pyplot as plt |
| 75 | + from .utils_matplotlib import make_mpl_figure |
| 76 | + from spikeinterface.widgets import plot_probe_map |
| 77 | + |
| 78 | + dp = to_attr(data_plot) |
| 79 | + |
| 80 | + peaks, peak_locations = self._check_and_format_inputs( |
| 81 | + dp.peaks, |
| 82 | + dp.peak_locations, |
| 83 | + ) |
| 84 | + fs = dp.recording.get_sampling_frequency() |
| 85 | + num_plots = len(peaks) |
| 86 | + |
| 87 | + # Set the maximum time to the end time of the longest segment |
| 88 | + if dp.time_range is None: |
| 89 | + |
| 90 | + time_range = self._get_min_and_max_times_in_recording(dp.recording) |
| 91 | + else: |
| 92 | + time_range = dp.time_range |
| 93 | + |
| 94 | + ## Create the figure and axes |
| 95 | + if backend_kwargs["figsize"] is None: |
| 96 | + backend_kwargs.update(dict(figsize=(12, 8))) |
| 97 | + |
| 98 | + self.figure, self.axes, self.ax = make_mpl_figure(num_axes=num_plots, **backend_kwargs) |
| 99 | + self.axes = self.axes[0] |
| 100 | + |
| 101 | + # Plot each passed peaks / peak_locations over the probe on a separate subplot |
| 102 | + for ax_idx, (peaks_to_plot, peak_locs_to_plot) in enumerate(zip(peaks, peak_locations)): |
| 103 | + |
| 104 | + ax = self.axes[ax_idx] |
| 105 | + plot_probe_map(dp.recording, ax=ax) |
| 106 | + |
| 107 | + time_mask = self._get_peaks_time_mask(dp.recording, time_range, peaks_to_plot) |
| 108 | + |
| 109 | + if dp.segment_index is not None: |
| 110 | + segment_mask = peaks_to_plot["segment_index"] == dp.segment_index |
| 111 | + mask = time_mask & segment_mask |
| 112 | + else: |
| 113 | + mask = time_mask |
| 114 | + |
| 115 | + if not any(mask): |
| 116 | + raise ValueError( |
| 117 | + "No peaks within the time and segment mask found. Change `time_range` or `segment_index`" |
| 118 | + ) |
| 119 | + |
| 120 | + # only plot every nth peak |
| 121 | + peak_slice = slice(None, None, dp.decimate) |
| 122 | + |
| 123 | + # Find the amplitudes for the colormap scaling |
| 124 | + # (intensity represents amplitude) |
| 125 | + amps = np.abs(peaks_to_plot["amplitude"][mask][peak_slice]) |
| 126 | + amps /= np.quantile(amps, 0.95) |
| 127 | + cmap = plt.get_cmap("inferno")(amps) |
| 128 | + color_kwargs = dict(alpha=0.2, s=2, c=cmap) |
| 129 | + |
| 130 | + # Plot the peaks over the plot, and set the y-axis limits. |
| 131 | + ax.scatter( |
| 132 | + peak_locs_to_plot["x"][mask][peak_slice], peak_locs_to_plot["y"][mask][peak_slice], **color_kwargs |
| 133 | + ) |
| 134 | + |
| 135 | + if dp.ylim is None: |
| 136 | + padding = 25 # arbitary padding just to give some space around highests and lowest peaks on the plot |
| 137 | + ylim = (np.min(peak_locs_to_plot["y"]) - padding, np.max(peak_locs_to_plot["y"]) + padding) |
| 138 | + else: |
| 139 | + ylim = dp.ylim |
| 140 | + |
| 141 | + ax.set_ylim(ylim[0], ylim[1]) |
| 142 | + |
| 143 | + self.figure.suptitle(f"Peaks on Probe Plot") |
| 144 | + |
| 145 | + def _get_peaks_time_mask(self, recording, time_range, peaks_to_plot): |
| 146 | + """ |
| 147 | + Return a mask of `True` where the peak is within the given time range |
| 148 | + and `False` otherwise. |
| 149 | +
|
| 150 | + This is a little complex, as each segment can have different start / |
| 151 | + end times. For each segment, find the time bounds relative to that |
| 152 | + segment time and fill the `time_mask` one segment at a time. |
| 153 | + """ |
| 154 | + time_mask = np.zeros(peaks_to_plot.size, dtype=bool) |
| 155 | + |
| 156 | + for seg_idx in range(recording.get_num_segments()): |
| 157 | + |
| 158 | + segment = recording.select_segments(seg_idx) |
| 159 | + |
| 160 | + t_start_sample = segment.time_to_sample_index(time_range[0]) |
| 161 | + t_stop_sample = segment.time_to_sample_index(time_range[1]) |
| 162 | + |
| 163 | + seg_mask = peaks_to_plot["segment_index"] == seg_idx |
| 164 | + |
| 165 | + time_mask[seg_mask] = (t_start_sample < peaks_to_plot[seg_mask]["sample_index"]) & ( |
| 166 | + peaks_to_plot[seg_mask]["sample_index"] < t_stop_sample |
| 167 | + ) |
| 168 | + |
| 169 | + return time_mask |
| 170 | + |
| 171 | + def _get_min_and_max_times_in_recording(self, recording): |
| 172 | + """ |
| 173 | + Find the maximum and minimum time across all segments in the recording. |
| 174 | + For example if the segment times are (10-100 s, 0 - 50s) the |
| 175 | + min and max times are (0, 100) |
| 176 | + """ |
| 177 | + t_starts = [] |
| 178 | + t_stops = [] |
| 179 | + for seg_idx in range(recording.get_num_segments()): |
| 180 | + |
| 181 | + segment = recording.select_segments(seg_idx) |
| 182 | + |
| 183 | + t_starts.append(segment.sample_index_to_time(0)) |
| 184 | + |
| 185 | + t_stops.append(segment.sample_index_to_time(segment.get_num_samples() - 1)) |
| 186 | + |
| 187 | + time_range = (np.min(t_starts), np.max(t_stops)) |
| 188 | + |
| 189 | + return time_range |
| 190 | + |
| 191 | + def _check_and_format_inputs(self, peaks, peak_locations): |
| 192 | + """ |
| 193 | + Check that the inpust are in expected form. Corresponding peaks |
| 194 | + and peak_locations of same size and format must be provided. |
| 195 | + """ |
| 196 | + types_are_list = [isinstance(peaks, list), isinstance(peak_locations, list)] |
| 197 | + |
| 198 | + if not all(types_are_list): |
| 199 | + if any(types_are_list): |
| 200 | + raise ValueError("`peaks` and `peak_locations` must either be both lists or both not lists.") |
| 201 | + peaks = [peaks] |
| 202 | + peak_locations = [peak_locations] |
| 203 | + |
| 204 | + if len(peaks) != len(peak_locations): |
| 205 | + raise ValueError( |
| 206 | + "If `peaks` and `peak_locations` are lists, they must contain " |
| 207 | + "the same number of (corresponding) peaks and peak locations." |
| 208 | + ) |
| 209 | + |
| 210 | + for idx, (peak, peak_loc) in enumerate(zip(peaks, peak_locations)): |
| 211 | + if peak.size != peak_loc.size: |
| 212 | + raise ValueError( |
| 213 | + f"The number of peaks and peak_locations do not " |
| 214 | + f"match for the {idx} input. For each spike peak, there " |
| 215 | + f"must be a corresponding peak location" |
| 216 | + ) |
| 217 | + |
| 218 | + return peaks, peak_locations |
0 commit comments