diff --git a/spikeinterface_gui/metricsview.py b/spikeinterface_gui/metricsview.py index d9075cb..d61f88b 100644 --- a/spikeinterface_gui/metricsview.py +++ b/spikeinterface_gui/metricsview.py @@ -111,27 +111,25 @@ def _qt_refresh(self): scatter.setData(x=values2, y=values1) - visible_unit_ids = self.controller.get_visible_unit_ids() - visible_unit_ids = self.controller.get_visible_unit_indices() - for unit_ind, unit_id in self.controller.iter_visible_units(): color = self.get_unit_color(unit_id) - scatter.addPoints(x=[values2[unit_ind]], y=[values1[unit_ind]], pen=pg.mkPen(None), brush=color) + if (not np.isnan(values2[unit_ind])) and (not np.isnan(values1[unit_ind])): + scatter.addPoints(x=[values2[unit_ind]], y=[values1[unit_ind]], pen=pg.mkPen(None), brush=color) - # self.scatter.addPoints(x=scatter_x[unit_id], y=scatter_y[unit_id], pen=pg.mkPen(None), brush=color) - # self.scatter_select.setData(selected_scatter_x, selected_scatter_y) elif c == r: values1 = units_table[visible_metrics[r]].values + values1_no_nans = values1[~np.isnan(values1)] - count, bins = np.histogram(values1, bins=self.settings['num_bins']) + count, bins = np.histogram(values1_no_nans, bins=self.settings['num_bins']) curve = pg.PlotCurveItem(bins, count, stepMode='center', fillLevel=0, brush=white_brush, pen=white_brush) plot.addItem(curve) for unit_ind, unit_id in self.controller.iter_visible_units(): x = values1[unit_ind] color = self.get_unit_color(unit_id) - line = pg.InfiniteLine(pos=x, angle=90, movable=False, pen=color) - plot.addItem(line) + if not np.isnan(x): + line = pg.InfiniteLine(pos=x, angle=90, movable=False, pen=color) + plot.addItem(line) def _qt_select_metrics(self): if not self.tree_visible_metrics.isVisible(): @@ -212,6 +210,8 @@ def _panel_refresh(self): col2 = visible_metrics[c] values1 = units_table[col1].values values2 = units_table[col2].values + values1_no_nans = values1[~np.isnan(values1)] + values2_no_nans = values2[~np.isnan(values2)] plot = bpl.figure( width=plot_size, height=plot_size, @@ -227,7 +227,7 @@ def _panel_refresh(self): plot.xaxis.axis_label = col1 plot.yaxis.axis_label = "Count" # Create histogram - hist, edges = np.histogram(values1, bins=self.settings['num_bins']) + hist, edges = np.histogram(values1_no_nans, bins=self.settings['num_bins']) if len(hist) > 0 and max(hist) > 0: plot.quad( top=hist, bottom=0, left=edges[:-1], right=edges[1:], @@ -238,8 +238,9 @@ def _panel_refresh(self): max_hist = max(hist) for unit_ind, unit_id in self.controller.iter_visible_units(): x = values1[unit_ind] - color = self.get_unit_color(unit_id) - plot.line([x, x], [0, max_hist], line_width=2, color=color, alpha=0.8) + if not np.isnan(x): + color = self.get_unit_color(unit_id) + plot.line([x, x], [0, max_hist], line_width=2, color=color, alpha=0.8) else: # Off-diagonal - scatter plot plot.xaxis.axis_label = col2 @@ -251,8 +252,8 @@ def _panel_refresh(self): # Plot all points in light color first all_source = ColumnDataSource({ - 'x': values2, - 'y': values1, + 'x': values2_no_nans, + 'y': values1_no_nans, 'color': colors }) plot.scatter('x', 'y', source=all_source, size=8, color='color', alpha=0.5)