Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 17 additions & 14 deletions spikeinterface_gui/metricsview.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,7 @@ def _qt_creat_grid(self):

def _qt_refresh(self):
import pyqtgraph as pg
import pandas as pd
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

what ?

from .myqt import QT


Expand All @@ -111,27 +112,25 @@ def _qt_refresh(self):

scatter.setData(x=values2, y=values1)

visible_unit_ids = self.controller.get_visible_unit_ids()
visible_unit_ids = self.controller.get_visible_unit_indices()

for unit_ind, unit_id in self.controller.iter_visible_units():
color = self.get_unit_color(unit_id)
scatter.addPoints(x=[values2[unit_ind]], y=[values1[unit_ind]], pen=pg.mkPen(None), brush=color)
if (not pd.isna(values2[unit_ind])) and (not pd.isna(values1[unit_ind])):
scatter.addPoints(x=[values2[unit_ind]], y=[values1[unit_ind]], pen=pg.mkPen(None), brush=color)

# self.scatter.addPoints(x=scatter_x[unit_id], y=scatter_y[unit_id], pen=pg.mkPen(None), brush=color)
# self.scatter_select.setData(selected_scatter_x, selected_scatter_y)
elif c == r:
values1 = units_table[visible_metrics[r]].values
values1_no_nans = values1[~np.isnan(values1)]

count, bins = np.histogram(values1, bins=self.settings['num_bins'])
count, bins = np.histogram(values1_no_nans, bins=self.settings['num_bins'])
curve = pg.PlotCurveItem(bins, count, stepMode='center', fillLevel=0, brush=white_brush, pen=white_brush)
plot.addItem(curve)

for unit_ind, unit_id in self.controller.iter_visible_units():
x = values1[unit_ind]
color = self.get_unit_color(unit_id)
line = pg.InfiniteLine(pos=x, angle=90, movable=False, pen=color)
plot.addItem(line)
if not pd.isna(x):
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@samuelgarcia when metrics is read, nans are read in as some weird pandas type, which we check here. That's why we need to import pandas. If the user has got quality metrics loaded, they already have used pandas in the unit_tables creation.

line = pg.InfiniteLine(pos=x, angle=90, movable=False, pen=color)
plot.addItem(line)

def _qt_select_metrics(self):
if not self.tree_visible_metrics.isVisible():
Expand Down Expand Up @@ -187,6 +186,7 @@ def _panel_on_metrics_changed(self, event):
self.refresh()

def _panel_refresh(self):
import pandas as pd
import panel as pn
import bokeh.plotting as bpl
from bokeh.layouts import gridplot
Expand All @@ -212,6 +212,8 @@ def _panel_refresh(self):
col2 = visible_metrics[c]
values1 = units_table[col1].values
values2 = units_table[col2].values
values1_no_nans = values1[~np.isnan(values1)]
values2_no_nans = values2[~np.isnan(values2)]

plot = bpl.figure(
width=plot_size, height=plot_size,
Expand All @@ -227,7 +229,7 @@ def _panel_refresh(self):
plot.xaxis.axis_label = col1
plot.yaxis.axis_label = "Count"
# Create histogram
hist, edges = np.histogram(values1, bins=self.settings['num_bins'])
hist, edges = np.histogram(values1_no_nans, bins=self.settings['num_bins'])
if len(hist) > 0 and max(hist) > 0:
plot.quad(
top=hist, bottom=0, left=edges[:-1], right=edges[1:],
Expand All @@ -238,8 +240,9 @@ def _panel_refresh(self):
max_hist = max(hist)
for unit_ind, unit_id in self.controller.iter_visible_units():
x = values1[unit_ind]
color = self.get_unit_color(unit_id)
plot.line([x, x], [0, max_hist], line_width=2, color=color, alpha=0.8)
if not pd.isna(x):
color = self.get_unit_color(unit_id)
plot.line([x, x], [0, max_hist], line_width=2, color=color, alpha=0.8)
else:
# Off-diagonal - scatter plot
plot.xaxis.axis_label = col2
Expand All @@ -251,8 +254,8 @@ def _panel_refresh(self):

# Plot all points in light color first
all_source = ColumnDataSource({
'x': values2,
'y': values1,
'x': values2_no_nans,
'y': values1_no_nans,
'color': colors
})
plot.scatter('x', 'y', source=all_source, size=8, color='color', alpha=0.5)
Expand Down