diff --git a/src/editor/view_editor.py b/src/editor/view_editor.py index 5300527..b7fa466 100644 --- a/src/editor/view_editor.py +++ b/src/editor/view_editor.py @@ -3,7 +3,7 @@ from .backend_customs import SaveLoadPltLinker, custom_buttons_setup from .state import State, StateLinker from .view_manager import Event, ViewManager -from .views.clusters_view import ClusterMainView, AgglomerativeView, DBSCANView +from .views.clusters_view import ClusterMainView, AgglomerativeView, DBSCANView, MergeView, AddView from .views.home_view import Home from .views.hulls_view import HullView from .views.labels_view import LabelsView, ArrowsView @@ -35,7 +35,9 @@ def run(self) -> None: HullView(vm), ClusterMainView(vm), AgglomerativeView(vm), - DBSCANView(vm)]) # must be the same as ViewsEnum + DBSCANView(vm), + MergeView(vm), + AddView(vm)]) # must be the same as ViewsEnum vm.run() # display diff --git a/src/editor/view_manager.py b/src/editor/view_manager.py index 94f3eb2..8935ee2 100644 --- a/src/editor/view_manager.py +++ b/src/editor/view_manager.py @@ -3,9 +3,7 @@ import logging from typing import Callable -from matplotlib.figure import Figure from matplotlib.backend_bases import FigureCanvasBase -from matplotlib.axes._axes import Axes from matplotlib.widgets import RadioButtons, Slider, Button, TextBox, CheckButtons from .artists import * @@ -18,15 +16,15 @@ class ViewsEnum(Enum): """ Enumeration of different views in the application. """ - HOME = 0 - LABELS = 1 - ARROWS = 2 - HULLS = 3 - CLUSTER = 4 - AGGLOMERATIVE = 5 - DBSCAN = 6 - CREATEHULL = 7 - REMOVELINE = 8 + HOME = 0 + LABELS = 1 + ARROWS = 2 + HULLS = 3 + CLUSTER = 4 + AGGLOMERATIVE = 5 + DBSCAN = 6 + MERGE = 7 + ADD = 8 class ViewManager: @@ -876,6 +874,24 @@ def show(self) -> None: self.ref.active = True self.ax.set_visible(True) + def get_index(self, label_name): + # gpt told me to use next, never seen that before + return next((i for i, text in enumerate(self.ref.labels) if text.get_text() == label_name), 0) + + def highlight_label(self, label_name, color): + index = self.get_index(label_name) + self.ref.labels[index].set_color(color) + self.ref.labels[index].set_weight("heavy") + self.ref.labels[index].set_alpha(1) + + def dehighlight_label(self, label_name): + index = self.get_index(label_name) + self.ref.labels[index].set_color("black") + self.ref.labels[index].set_weight("normal") + self.ref.labels[index].set_alpha(0.5) + + def hide_props(self): + self.ref._buttons.set_alpha(0) class ViewSlider(ViewElement): diff --git a/src/editor/views/clusters_view.py b/src/editor/views/clusters_view.py index 1c06e59..12961bb 100644 --- a/src/editor/views/clusters_view.py +++ b/src/editor/views/clusters_view.py @@ -21,7 +21,9 @@ def __init__(self, view_manager: ViewManager) -> None: self.vem.add(ChangeViewButton(self, [0.75, self.change_button_y, self.change_button_length, self.change_button_height], "Agglo", ViewsEnum.AGGLOMERATIVE)) self.vem.add(ChangeViewButton(self, [0.85, self.change_button_y, self.change_button_length, self.change_button_height], "DBSCAN", ViewsEnum.DBSCAN)) self.vem.add(NormalButton(self, [0.05, 0.05, 0.1, 0.075], "Remove", self.remove_point)) - self.vem.add(NormalButton(self, [0.17, 0.05, 0.15, 0.075], "Toggle hulls", self.draw_hull)) + self.vem.add(NormalButton(self, [0.17, 0.05, 0.1, 0.075], "Add", lambda: self.change_view(ViewsEnum.ADD))) + self.vem.add(NormalButton(self, [0.29, 0.05, 0.1, 0.075], "Merge", lambda: self.change_view(ViewsEnum.MERGE))) + self.vem.add(NormalButton(self, [0.68, 0.05, 0.15, 0.075], "Toggle hulls", self.draw_hull)) reset_b = self.vem.add(NormalButton(self, [0.85, 0.05, 0.1, 0.075], "Reset", self.reset_clusters)) self.info = self.vem.add(ViewText(self.vm.ax, 0, 0, "Info")) @@ -156,6 +158,289 @@ def draw_hull(self): plt.draw() self.vm.list_manager.clusters_view_hull_off = self.hulls_off +class AddView(View): + def __init__(self, view_manager: ViewManager) -> None: + super().__init__(view_manager) + self.vem.add(ChangeViewButton(self, self.home_ax, "Home", ViewsEnum.HOME)) + view_button = self.vem.add(ChangeViewButton(self, self.clusters_ax, "Cluster", ViewsEnum.CLUSTER)) + self.vem.add(ChangeViewButton(self, self.labels_ax, "Labels", ViewsEnum.LABELS)) + self.vem.add(ChangeViewButton(self, self.hulls_ax, "Hulls", ViewsEnum.HULLS)) + view_button.highlight() + + self.submitted = False + self.picked_points = [] + self.point_artists = {} + self.cluster_name = None + self.widget = None + self.vem.add(NormalButton(self, [0.43, 0.05, 0.1, 0.075], "Back", lambda: self.change_view(ViewsEnum.CLUSTER))) + self.vem.add(NormalButton(self, [0.55, 0.05, 0.15, 0.075], "Submit", self.submit)) + reset_b = self.vem.add(NormalButton(self, [0.72, 0.05, 0.1, 0.075], "Reset", self.reset_clusters)) + reset_b.button_ax.set_facecolor("lightcoral") + reset_b.button_ref.color = "lightcoral" + reset_b.button_ref.hovercolor = "crimson" + self.vem.hide() + + def draw(self, *args, **kwargs) -> None: + super().draw() + self.state.hide_labels_and_hulls(self.vm.ax) + self.vm.list_manager.hide_button() + + # make points more transparent + for artist in self.state.data['clusters_data']['artists']: + artist.set_alpha(0.3) + + self.cem.add(SharedEvent('pick_event', self.pick_event)) + + self.widget = self.vem.add(ViewRadioButtons(self, [-0.05, 0.15, 0.3, 0.75], + sorted(list(self.state.get_all_clusters().keys())), self.update, 0)) + self.cluster_name = self.widget.ref.value_selected + self.widget.ref.set_label_props(dict(alpha=[0.5])) + self.highlight_cluster(self.cluster_name, "green") + self.widget.highlight_label(self.cluster_name, "darkgreen") + self.widget.hide_props() + + # make main plot larger + df = self.state.get_all_points() + # setting lims manually since relim and autoscale don't perform well + self.vm.ax.set_xlim(df['x'].min() - 10, df['x'].max() + 10) + self.vm.ax.set_ylim(df['y'].min() - 10, df['y'].max() + 10) + plt.subplots_adjust(bottom=0.15, left=0.25, right=0.99, top=0.935) + + plt.draw() + + def submit(self): + self.submitted = True + self.dehighlight_cluster(self.cluster_name) + + points = self.state.get_cluster(self.cluster_name).index.tolist() + points.extend(self.picked_points) + self.state.set_cluster(self.cluster_name, points) + + self.remove_artists() + self.highlight_cluster(self.cluster_name, "black") + plt.draw() + + def update(self, _): + self.dehighlight_cluster(self.cluster_name) + self.widget.dehighlight_label(self.cluster_name) + + self.cluster_name = self.widget.ref.value_selected + self.highlight_cluster(self.cluster_name, "green") + self.widget.highlight_label(self.cluster_name, "darkgreen") + plt.draw() + + def remove_artists(self): + for artist in self.point_artists.values(): + artist.remove() + self.point_artists = {} + self.picked_points = [] + + def pick_event(self, event: PickEvent) -> None: + if self.submitted: + self.submitted = False + self.highlight_cluster(self.cluster_name, "green") + if event.artist.id in self.picked_points: + self.point_artists[event.artist.id].remove() + self.point_artists.pop(event.artist.id, None) + self.picked_points.remove(event.artist.id) + else: + picked_item = PointArtist.point(self.vm.ax, event.artist.id, facecolor="red", edgecolor="maroon", zorder=20) + self.point_artists[picked_item.id] = picked_item + self.picked_points.append(picked_item.id) + plt.draw() + + def reset_clusters(self): + self.dehighlight_cluster(self.cluster_name) + self.state.reset_clusters() + for artist in self.state.data['clusters_data']['artists']: + artist.set_color(self.state.get_point_color(artist.id)) + self.hide() + self.draw() + plt.draw() + + def highlight_cluster(self, cluster_name, color): + """Makes currently picked cluster points more visible""" + idx = 0 + for point_id in self.state.get_cluster(cluster_name).index: + artist = self.state.data['clusters_data']['artists'][point_id] + artist.set_color(color) + artist.set_alpha(1) + artist.set_radius(2.2) + artist.set_zorder(10) + idx += 1 + + def dehighlight_cluster(self, cluster_name): + """Resets currently picked cluster points to their original look""" + for point_id in self.state.get_cluster(cluster_name).index: + artist = self.state.data['clusters_data']['artists'][point_id] + artist.set_color(self.state.get_point_color(point_id)) + artist.set_alpha(0.3) + artist.set_radius(1.5) + artist.set_zorder(1) + + def hide(self) -> None: + super().hide() + self.vem.remove(self.widget) + self.remove_artists() + self.dehighlight_cluster(self.cluster_name) + for artist in self.state.data['clusters_data']['artists']: + artist.set_alpha(1) + self.vm.ax.set_xlim(-190, 190) + self.vm.ax.set_ylim(-150, 150) + plt.subplots_adjust(bottom=0.15, left=0.01, right=0.99, top=0.935) + +class MergeView(View): + def __init__(self, view_manager: ViewManager) -> None: + super().__init__(view_manager) + self.vem.add(ChangeViewButton(self, self.home_ax, "Home", ViewsEnum.HOME)) + view_button = self.vem.add(ChangeViewButton(self, self.clusters_ax, "Cluster", ViewsEnum.CLUSTER)) + self.vem.add(ChangeViewButton(self, self.labels_ax, "Labels", ViewsEnum.LABELS)) + self.vem.add(ChangeViewButton(self, self.hulls_ax, "Hulls", ViewsEnum.HULLS)) + view_button.highlight() + + self.submitted = False + self.cluster_name_right = None + self.cluster_name_left = None + self.widget_right = None + self.widget_left = None + self.vem.add(NormalButton(self, [0.28, 0.05, 0.1, 0.075], "Back", lambda: self.change_view(ViewsEnum.CLUSTER))) + self.vem.add(NormalButton(self, [0.4, 0.05, 0.15, 0.075], "Submit", self.merge)) + + reset_b = self.vem.add(NormalButton(self, [0.57, 0.05, 0.1, 0.075], "Reset", self.reset_clusters)) + reset_b.button_ax.set_facecolor("lightcoral") + reset_b.button_ref.color = "lightcoral" + reset_b.button_ref.hovercolor = "crimson" + + self.vem.hide() + + def draw(self, start=True, *args, **kwargs) -> None: + super().draw() + self.state.hide_labels_and_hulls(self.vm.ax) + self.vm.list_manager.hide_button() + + # make points more transparent + for artist in self.state.data['clusters_data']['artists']: + artist.set_alpha(0.3) + + self.widget_left = self.vem.add(ViewRadioButtons(self, [-0.05, 0.15, 0.3, 0.75], + sorted(list(self.state.get_all_clusters().keys())), self.update_left, 0)) + self.widget_left.ref.set_label_props(dict(alpha=[0.5])) + self.widget_left.hide_props() + + self.widget_right = self.vem.add(ViewRadioButtons(self, [0.72, 0.15, 0.3, 0.75], + sorted(list(self.state.get_all_clusters().keys())), self.update_right, 1)) + self.widget_right.ref.set_label_props(dict(alpha=[0.5])) + self.widget_right.hide_props() + + if start: + self.starting_look() + + plt.draw() + + def starting_look(self): + self.cluster_name_left = self.widget_left.ref.value_selected + self.highlight_cluster(self.cluster_name_left, "green") + self.widget_left.highlight_label(self.cluster_name_left, "darkgreen") + + self.cluster_name_right = self.widget_right.ref.value_selected + self.highlight_cluster(self.cluster_name_right, "red") + self.widget_right.highlight_label(self.cluster_name_right, "red") + + def merge(self): + if self.cluster_name_right == self.cluster_name_left or self.submitted: + return + self.submitted = True + + # dehighlight since changing state will make those names meaningless + self.dehighlight_cluster(self.cluster_name_right) + self.dehighlight_cluster(self.cluster_name_left) + + name = self.cluster_name_right + "_" + self.cluster_name_left + points = self.state.get_cluster(self.cluster_name_right).index.tolist() + points.extend(self.state.get_cluster(self.cluster_name_left).index.tolist()) + self.state.set_cluster(name, points) + + # redrawing since recreating those widgets takes many lines of code + self.hide() + self.draw(False) + + # setting names so that update works as supposed + self.cluster_name_right = name + self.cluster_name_left = name + self.widget_left.highlight_label(self.cluster_name_left, "darkgreen") + self.widget_right.highlight_label(self.cluster_name_right, "red") + self.highlight_cluster(name, "black") + + plt.draw() + + def reset_clusters(self): + self.submitted = False + self.dehighlight_cluster(self.cluster_name_left) + self.dehighlight_cluster(self.cluster_name_right) + self.state.reset_clusters() + for artist in self.state.data['clusters_data']['artists']: + artist.set_color(self.state.get_point_color(artist.id)) + self.hide() + self.draw() + plt.draw() + + def update_left(self, _): + self.submitted = False + self.dehighlight_cluster(self.cluster_name_left) + self.widget_left.dehighlight_label(self.cluster_name_left) + + # when the same cluster was chosen, leave the other color highlight + if self.cluster_name_left == self.cluster_name_right: + self.highlight_cluster(self.cluster_name_right, "red") + + self.cluster_name_left = self.widget_left.ref.value_selected + self.highlight_cluster(self.cluster_name_left, "green") + self.widget_left.highlight_label(self.cluster_name_left, "darkgreen") + plt.draw() + + def update_right(self, _): + self.submitted = False + self.dehighlight_cluster(self.cluster_name_right) + self.widget_right.dehighlight_label(self.cluster_name_right) + + # when the same cluster was chosen, leave the other color highlight + if self.cluster_name_left == self.cluster_name_right: + self.highlight_cluster(self.cluster_name_left, "green") + + self.cluster_name_right = self.widget_right.ref.value_selected + self.highlight_cluster(self.cluster_name_right, "red") + self.widget_right.highlight_label(self.cluster_name_right, "red") + plt.draw() + + def highlight_cluster(self, cluster_name, color): + """Makes currently picked cluster points more visible""" + idx = 0 + for point_id in self.state.get_cluster(cluster_name).index: + artist = self.state.data['clusters_data']['artists'][point_id] + artist.set_color(color) + artist.set_alpha(1) + artist.set_radius(2.5) + artist.set_zorder(10) + idx += 1 + + def dehighlight_cluster(self, cluster_name): + """Resets currently picked cluster points to their original look""" + for point_id in self.state.get_cluster(cluster_name).index: + artist = self.state.data['clusters_data']['artists'][point_id] + artist.set_color(self.state.get_point_color(point_id)) + artist.set_alpha(0.3) + artist.set_radius(1.5) + artist.set_zorder(1) + + def hide(self) -> None: + super().hide() + self.vem.remove(self.widget_left) + self.vem.remove(self.widget_right) + self.dehighlight_cluster(self.cluster_name_left) + self.dehighlight_cluster(self.cluster_name_right) + for artist in self.state.data['clusters_data']['artists']: + artist.set_alpha(1) + class ClusteringSubViewBase(View): @abstractmethod