Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions src/editor/view_editor.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from .backend_customs import SaveLoadPltLinker, custom_buttons_setup
from .state import State, StateLinker
from .view_manager import Event, ViewManager
from .views.clusters_view import ClusterMainView, AgglomerativeView, DBSCANView
from .views.clusters_view import ClusterMainView, AgglomerativeView, DBSCANView, MergeView, AddView
from .views.home_view import Home
from .views.hulls_view import HullView
from .views.labels_view import LabelsView, ArrowsView
Expand Down Expand Up @@ -35,7 +35,9 @@ def run(self) -> None:
HullView(vm),
ClusterMainView(vm),
AgglomerativeView(vm),
DBSCANView(vm)]) # must be the same as ViewsEnum
DBSCANView(vm),
MergeView(vm),
AddView(vm)]) # must be the same as ViewsEnum
vm.run()

# display
Expand Down
38 changes: 27 additions & 11 deletions src/editor/view_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,7 @@
import logging
from typing import Callable

from matplotlib.figure import Figure
from matplotlib.backend_bases import FigureCanvasBase
from matplotlib.axes._axes import Axes
from matplotlib.widgets import RadioButtons, Slider, Button, TextBox, CheckButtons

from .artists import *
Expand All @@ -18,15 +16,15 @@ class ViewsEnum(Enum):
"""
Enumeration of different views in the application.
"""
HOME = 0
LABELS = 1
ARROWS = 2
HULLS = 3
CLUSTER = 4
AGGLOMERATIVE = 5
DBSCAN = 6
CREATEHULL = 7
REMOVELINE = 8
HOME = 0
LABELS = 1
ARROWS = 2
HULLS = 3
CLUSTER = 4
AGGLOMERATIVE = 5
DBSCAN = 6
MERGE = 7
ADD = 8


class ViewManager:
Expand Down Expand Up @@ -876,6 +874,24 @@ def show(self) -> None:
self.ref.active = True
self.ax.set_visible(True)

def get_index(self, label_name):
# gpt told me to use next, never seen that before
return next((i for i, text in enumerate(self.ref.labels) if text.get_text() == label_name), 0)

def highlight_label(self, label_name, color):
index = self.get_index(label_name)
self.ref.labels[index].set_color(color)
self.ref.labels[index].set_weight("heavy")
self.ref.labels[index].set_alpha(1)

def dehighlight_label(self, label_name):
index = self.get_index(label_name)
self.ref.labels[index].set_color("black")
self.ref.labels[index].set_weight("normal")
self.ref.labels[index].set_alpha(0.5)

def hide_props(self):
self.ref._buttons.set_alpha(0)

class ViewSlider(ViewElement):

Expand Down
287 changes: 286 additions & 1 deletion src/editor/views/clusters_view.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,9 @@ def __init__(self, view_manager: ViewManager) -> None:
self.vem.add(ChangeViewButton(self, [0.75, self.change_button_y, self.change_button_length, self.change_button_height], "Agglo", ViewsEnum.AGGLOMERATIVE))
self.vem.add(ChangeViewButton(self, [0.85, self.change_button_y, self.change_button_length, self.change_button_height], "DBSCAN", ViewsEnum.DBSCAN))
self.vem.add(NormalButton(self, [0.05, 0.05, 0.1, 0.075], "Remove", self.remove_point))
self.vem.add(NormalButton(self, [0.17, 0.05, 0.15, 0.075], "Toggle hulls", self.draw_hull))
self.vem.add(NormalButton(self, [0.17, 0.05, 0.1, 0.075], "Add", lambda: self.change_view(ViewsEnum.ADD)))
self.vem.add(NormalButton(self, [0.29, 0.05, 0.1, 0.075], "Merge", lambda: self.change_view(ViewsEnum.MERGE)))
self.vem.add(NormalButton(self, [0.68, 0.05, 0.15, 0.075], "Toggle hulls", self.draw_hull))
reset_b = self.vem.add(NormalButton(self, [0.85, 0.05, 0.1, 0.075], "Reset", self.reset_clusters))
self.info = self.vem.add(ViewText(self.vm.ax, 0, 0, "Info"))

Expand Down Expand Up @@ -156,6 +158,289 @@ def draw_hull(self):
plt.draw()
self.vm.list_manager.clusters_view_hull_off = self.hulls_off

class AddView(View):
def __init__(self, view_manager: ViewManager) -> None:
super().__init__(view_manager)
self.vem.add(ChangeViewButton(self, self.home_ax, "Home", ViewsEnum.HOME))
view_button = self.vem.add(ChangeViewButton(self, self.clusters_ax, "Cluster", ViewsEnum.CLUSTER))
self.vem.add(ChangeViewButton(self, self.labels_ax, "Labels", ViewsEnum.LABELS))
self.vem.add(ChangeViewButton(self, self.hulls_ax, "Hulls", ViewsEnum.HULLS))
view_button.highlight()

self.submitted = False
self.picked_points = []
self.point_artists = {}
self.cluster_name = None
self.widget = None
self.vem.add(NormalButton(self, [0.43, 0.05, 0.1, 0.075], "Back", lambda: self.change_view(ViewsEnum.CLUSTER)))
self.vem.add(NormalButton(self, [0.55, 0.05, 0.15, 0.075], "Submit", self.submit))
reset_b = self.vem.add(NormalButton(self, [0.72, 0.05, 0.1, 0.075], "Reset", self.reset_clusters))
reset_b.button_ax.set_facecolor("lightcoral")
reset_b.button_ref.color = "lightcoral"
reset_b.button_ref.hovercolor = "crimson"
self.vem.hide()

def draw(self, *args, **kwargs) -> None:
super().draw()
self.state.hide_labels_and_hulls(self.vm.ax)
self.vm.list_manager.hide_button()

# make points more transparent
for artist in self.state.data['clusters_data']['artists']:
artist.set_alpha(0.3)

self.cem.add(SharedEvent('pick_event', self.pick_event))

self.widget = self.vem.add(ViewRadioButtons(self, [-0.05, 0.15, 0.3, 0.75],
sorted(list(self.state.get_all_clusters().keys())), self.update, 0))
self.cluster_name = self.widget.ref.value_selected
self.widget.ref.set_label_props(dict(alpha=[0.5]))
self.highlight_cluster(self.cluster_name, "green")
self.widget.highlight_label(self.cluster_name, "darkgreen")
self.widget.hide_props()

# make main plot larger
df = self.state.get_all_points()
# setting lims manually since relim and autoscale don't perform well
self.vm.ax.set_xlim(df['x'].min() - 10, df['x'].max() + 10)
self.vm.ax.set_ylim(df['y'].min() - 10, df['y'].max() + 10)
plt.subplots_adjust(bottom=0.15, left=0.25, right=0.99, top=0.935)

plt.draw()

def submit(self):
self.submitted = True
self.dehighlight_cluster(self.cluster_name)

points = self.state.get_cluster(self.cluster_name).index.tolist()
points.extend(self.picked_points)
self.state.set_cluster(self.cluster_name, points)

self.remove_artists()
self.highlight_cluster(self.cluster_name, "black")
plt.draw()

def update(self, _):
self.dehighlight_cluster(self.cluster_name)
self.widget.dehighlight_label(self.cluster_name)

self.cluster_name = self.widget.ref.value_selected
self.highlight_cluster(self.cluster_name, "green")
self.widget.highlight_label(self.cluster_name, "darkgreen")
plt.draw()

def remove_artists(self):
for artist in self.point_artists.values():
artist.remove()
self.point_artists = {}
self.picked_points = []

def pick_event(self, event: PickEvent) -> None:
if self.submitted:
self.submitted = False
self.highlight_cluster(self.cluster_name, "green")
if event.artist.id in self.picked_points:
self.point_artists[event.artist.id].remove()
self.point_artists.pop(event.artist.id, None)
self.picked_points.remove(event.artist.id)
else:
picked_item = PointArtist.point(self.vm.ax, event.artist.id, facecolor="red", edgecolor="maroon", zorder=20)
self.point_artists[picked_item.id] = picked_item
self.picked_points.append(picked_item.id)
plt.draw()

def reset_clusters(self):
self.dehighlight_cluster(self.cluster_name)
self.state.reset_clusters()
for artist in self.state.data['clusters_data']['artists']:
artist.set_color(self.state.get_point_color(artist.id))
self.hide()
self.draw()
plt.draw()

def highlight_cluster(self, cluster_name, color):
"""Makes currently picked cluster points more visible"""
idx = 0
for point_id in self.state.get_cluster(cluster_name).index:
artist = self.state.data['clusters_data']['artists'][point_id]
artist.set_color(color)
artist.set_alpha(1)
artist.set_radius(2.2)
artist.set_zorder(10)
idx += 1

def dehighlight_cluster(self, cluster_name):
"""Resets currently picked cluster points to their original look"""
for point_id in self.state.get_cluster(cluster_name).index:
artist = self.state.data['clusters_data']['artists'][point_id]
artist.set_color(self.state.get_point_color(point_id))
artist.set_alpha(0.3)
artist.set_radius(1.5)
artist.set_zorder(1)

def hide(self) -> None:
super().hide()
self.vem.remove(self.widget)
self.remove_artists()
self.dehighlight_cluster(self.cluster_name)
for artist in self.state.data['clusters_data']['artists']:
artist.set_alpha(1)
self.vm.ax.set_xlim(-190, 190)
self.vm.ax.set_ylim(-150, 150)
plt.subplots_adjust(bottom=0.15, left=0.01, right=0.99, top=0.935)

class MergeView(View):
def __init__(self, view_manager: ViewManager) -> None:
super().__init__(view_manager)
self.vem.add(ChangeViewButton(self, self.home_ax, "Home", ViewsEnum.HOME))
view_button = self.vem.add(ChangeViewButton(self, self.clusters_ax, "Cluster", ViewsEnum.CLUSTER))
self.vem.add(ChangeViewButton(self, self.labels_ax, "Labels", ViewsEnum.LABELS))
self.vem.add(ChangeViewButton(self, self.hulls_ax, "Hulls", ViewsEnum.HULLS))
view_button.highlight()

self.submitted = False
self.cluster_name_right = None
self.cluster_name_left = None
self.widget_right = None
self.widget_left = None
self.vem.add(NormalButton(self, [0.28, 0.05, 0.1, 0.075], "Back", lambda: self.change_view(ViewsEnum.CLUSTER)))
self.vem.add(NormalButton(self, [0.4, 0.05, 0.15, 0.075], "Submit", self.merge))

reset_b = self.vem.add(NormalButton(self, [0.57, 0.05, 0.1, 0.075], "Reset", self.reset_clusters))
reset_b.button_ax.set_facecolor("lightcoral")
reset_b.button_ref.color = "lightcoral"
reset_b.button_ref.hovercolor = "crimson"

self.vem.hide()

def draw(self, start=True, *args, **kwargs) -> None:
super().draw()
self.state.hide_labels_and_hulls(self.vm.ax)
self.vm.list_manager.hide_button()

# make points more transparent
for artist in self.state.data['clusters_data']['artists']:
artist.set_alpha(0.3)

self.widget_left = self.vem.add(ViewRadioButtons(self, [-0.05, 0.15, 0.3, 0.75],
sorted(list(self.state.get_all_clusters().keys())), self.update_left, 0))
self.widget_left.ref.set_label_props(dict(alpha=[0.5]))
self.widget_left.hide_props()

self.widget_right = self.vem.add(ViewRadioButtons(self, [0.72, 0.15, 0.3, 0.75],
sorted(list(self.state.get_all_clusters().keys())), self.update_right, 1))
self.widget_right.ref.set_label_props(dict(alpha=[0.5]))
self.widget_right.hide_props()

if start:
self.starting_look()

plt.draw()

def starting_look(self):
self.cluster_name_left = self.widget_left.ref.value_selected
self.highlight_cluster(self.cluster_name_left, "green")
self.widget_left.highlight_label(self.cluster_name_left, "darkgreen")

self.cluster_name_right = self.widget_right.ref.value_selected
self.highlight_cluster(self.cluster_name_right, "red")
self.widget_right.highlight_label(self.cluster_name_right, "red")

def merge(self):
if self.cluster_name_right == self.cluster_name_left or self.submitted:
return
self.submitted = True

# dehighlight since changing state will make those names meaningless
self.dehighlight_cluster(self.cluster_name_right)
self.dehighlight_cluster(self.cluster_name_left)

name = self.cluster_name_right + "_" + self.cluster_name_left
points = self.state.get_cluster(self.cluster_name_right).index.tolist()
points.extend(self.state.get_cluster(self.cluster_name_left).index.tolist())
self.state.set_cluster(name, points)

# redrawing since recreating those widgets takes many lines of code
self.hide()
self.draw(False)

# setting names so that update works as supposed
self.cluster_name_right = name
self.cluster_name_left = name
self.widget_left.highlight_label(self.cluster_name_left, "darkgreen")
self.widget_right.highlight_label(self.cluster_name_right, "red")
self.highlight_cluster(name, "black")

plt.draw()

def reset_clusters(self):
self.submitted = False
self.dehighlight_cluster(self.cluster_name_left)
self.dehighlight_cluster(self.cluster_name_right)
self.state.reset_clusters()
for artist in self.state.data['clusters_data']['artists']:
artist.set_color(self.state.get_point_color(artist.id))
self.hide()
self.draw()
plt.draw()

def update_left(self, _):
self.submitted = False
self.dehighlight_cluster(self.cluster_name_left)
self.widget_left.dehighlight_label(self.cluster_name_left)

# when the same cluster was chosen, leave the other color highlight
if self.cluster_name_left == self.cluster_name_right:
self.highlight_cluster(self.cluster_name_right, "red")

self.cluster_name_left = self.widget_left.ref.value_selected
self.highlight_cluster(self.cluster_name_left, "green")
self.widget_left.highlight_label(self.cluster_name_left, "darkgreen")
plt.draw()

def update_right(self, _):
self.submitted = False
self.dehighlight_cluster(self.cluster_name_right)
self.widget_right.dehighlight_label(self.cluster_name_right)

# when the same cluster was chosen, leave the other color highlight
if self.cluster_name_left == self.cluster_name_right:
self.highlight_cluster(self.cluster_name_left, "green")

self.cluster_name_right = self.widget_right.ref.value_selected
self.highlight_cluster(self.cluster_name_right, "red")
self.widget_right.highlight_label(self.cluster_name_right, "red")
plt.draw()

def highlight_cluster(self, cluster_name, color):
"""Makes currently picked cluster points more visible"""
idx = 0
for point_id in self.state.get_cluster(cluster_name).index:
artist = self.state.data['clusters_data']['artists'][point_id]
artist.set_color(color)
artist.set_alpha(1)
artist.set_radius(2.5)
artist.set_zorder(10)
idx += 1

def dehighlight_cluster(self, cluster_name):
"""Resets currently picked cluster points to their original look"""
for point_id in self.state.get_cluster(cluster_name).index:
artist = self.state.data['clusters_data']['artists'][point_id]
artist.set_color(self.state.get_point_color(point_id))
artist.set_alpha(0.3)
artist.set_radius(1.5)
artist.set_zorder(1)

def hide(self) -> None:
super().hide()
self.vem.remove(self.widget_left)
self.vem.remove(self.widget_right)
self.dehighlight_cluster(self.cluster_name_left)
self.dehighlight_cluster(self.cluster_name_right)
for artist in self.state.data['clusters_data']['artists']:
artist.set_alpha(1)


class ClusteringSubViewBase(View):
@abstractmethod
Expand Down
Loading