diff --git a/src/manim_data_structures/bs_tree.py b/src/manim_data_structures/bs_tree.py new file mode 100644 index 0000000..4213abe --- /dev/null +++ b/src/manim_data_structures/bs_tree.py @@ -0,0 +1,144 @@ +from typing import Any, Callable, Hashable + +import nary_tree as nt +import networkx as nx +from manim import Mobject + + +class BSTree(nt.NaryTree): + + # BSTree is a NaryTree with num_children=2. + # BSTree also keeps track of the values of the nodes in the tree with __values + def __init__( + self, + nodes: dict[int, Any], + vertex_type: Callable[..., Mobject], + edge_buff=0.4, + layout_config={"vertex_spacing": (0.6, -0.6)}, + **kwargs + ): + self.__values = nodes + super().__init__(nodes, 2, vertex_type, edge_buff, layout_config, **kwargs) + + def insert_node(self, node: Any): + """ + Inserts a node into the BST, overriding definition in NaryTree + """ + # find the place to insert based on BST properties, starts at index=0 + ind = self.find_insertion_index(node) + # take note of the new node's value + self.__values[ind] = node + # insert the new node into the correct position using NaryTree's implementation + return super().insert_node(node, ind) + + def remove_node(self, node: Any): + """ + Removes a node with value 'node' from the BST + """ + ind = -1 + # does not support removing the root + for k, v in self.__values.items(): + if k != 0 and v == node: + ind = k + break + if ind == -1: + return + removal_ind = ind + # if left child exists + if ind * 2 + 1 in self.__values: + ind = ind * 2 + 1 + # find rightmost + while ind * 2 + 2 in self.__values: + ind = ind * 2 + 2 + # swap the rightmost down its left + while ind * 2 + 1 in self.__values: + self.__values[ind], self.__values[ind * 2 + 1] = ( + self.__values[ind * 2 + 1], + self.__values[ind], + ) + ind = ind * 2 + 1 + # remove + self.__values[ind], self.__values[removal_ind] = ( + self.__values[removal_ind], + self.__values[ind], + ) + del self.__values[ind] + self._graph.remove_vertices(ind) + # get rid of graph + for vert in self.__values: + if vert != 0: + self._graph.remove_vertices(vert) + # rebuild; this also makes edges + for k, v in self.__values.items(): + if k != 0: + super().insert_node(v, k) + # choose right child if left is not present + elif ind * 2 + 2 in self._graph.vertices: + ind = ind * 2 + 2 + while ind * 2 + 1 in self.__values: + ind = ind * 2 + 1 + while ind * 2 + 2 in self.__values: + self.__values[ind], self.__values[ind * 2 + 2] = ( + self.__values[ind * 2 + 2], + self.__values[ind], + ) + ind = ind * 2 + 2 + self.__values[ind], self.__values[removal_ind] = ( + self.__values[removal_ind], + self.__values[ind], + ) + del self.__values[ind] + self._graph.remove_vertices(ind) + + for vert in self.__values: + if vert != 0: + self._graph.remove_vertices(vert) + for k, v in self.__values.items(): + if k != 0: + super().insert_node(v, k) + # just delete it if it has no children + else: + self._graph.remove_vertices(removal_ind) + del self.__values[removal_ind] + + def find_insertion_index(self, node: Any, index=0): + """ + Finds the position where a value 'node' should be placed + """ + # if we reach an index not yet created, we are done + if index not in self._graph.vertices: + return index + # otherwise, try to find position at left or right based on value + if node <= self.__values[index]: + return self.find_insertion_index(node, index * 2 + 1) + else: + return self.find_insertion_index(node, index * 2 + 2) + + +if __name__ == "__main__": + from manim import * + + class TestScene(Scene): + def construct(self): + tree = BSTree({0: 10}, vertex_type=Integer) + self.play(Create(tree)) + tree.insert_node(5) + self.wait() + tree.insert_node(15) + self.wait() + tree.insert_node(20) + self.wait() + tree.insert_node(12) + self.wait() + tree.insert_node(14) + self.wait() + tree.insert_node(13) + self.wait() + tree.remove_node(15) + self.wait() + self.wait() + + config.preview = True + config.renderer = "cairo" + config.quality = "high_quality" + TestScene().render() diff --git a/src/manim_data_structures/m_tree.py b/src/manim_data_structures/m_tree.py new file mode 100644 index 0000000..625ac5a --- /dev/null +++ b/src/manim_data_structures/m_tree.py @@ -0,0 +1,144 @@ +import operator as op +import random +from collections import defaultdict +from copy import copy +from functools import partialmethod, reduce +from typing import Any, Callable, Dict, Hashable, List, Tuple + +import numpy as np +from manim import * +from manim import WHITE, Graph, Mobject, VMobject + + +class Tree(VMobject): + """Computer Science Tree Data Structure""" + + _graph: Graph + __layout_config: dict + __layout_scale: float + __layout: str | dict + __vertex_type: Callable[..., Mobject] + + # __parents: list + # __children: dict[Hashable, list] = defaultdict(list) + + def __init__( + self, + nodes: dict[int, Any], + edges: list[tuple[int, int]], + vertex_type: Callable[..., Mobject], + edge_buff=0.4, + layout="tree", + layout_config={"vertex_spacing": (-1, 1)}, + root_vertex=0, + **kwargs + ): + super().__init__(**kwargs) + vertex_mobjects = {k: vertex_type(v) for k, v in nodes.items()} + self.__layout_config = layout_config + self.__layout_scale = len(nodes) * 0.5 + self.__layout = layout + self.__vertex_type = vertex_type + self._graph = Graph( + list(nodes), + edges, + vertex_mobjects=vertex_mobjects, + layout=layout, + root_vertex=0, + layout_config=self.__layout_config, + layout_scale=len(nodes) * 0.5, + edge_config={"stroke_width": 1, "stroke_color": WHITE}, + ) + + def update_edges(graph: Graph): + """Updates edges of graph""" + for (u, v), edge in graph.edges.items(): + buff_vec = ( + edge_buff + * (graph[u].get_center() - graph[v].get_center()) + / np.linalg.norm(graph[u].get_center() - graph[v].get_center()) + ) + edge.put_start_and_end_on( + graph[u].get_center() - buff_vec, graph[v].get_center() + buff_vec + ) + + self._graph.updaters.clear() + self._graph.updaters.append(update_edges) + self.add(self._graph) + + def insert_node(self, node: Any, edge: tuple[Hashable, Hashable]): + """Inserts a node into the graph as (parent, node)""" + self._graph.add_vertices( + edge[1], vertex_mobjects={edge[1]: self.__vertex_type(node)} + ) + self._graph.add_edges(edge) + return self + + def insert_node2(self, node: Any, edge: tuple[Hashable, Hashable]): + """Inserts a node into the graph as (parent, node)""" + self._graph.change_layout( + self.__layout, + layout_scale=self.__layout_scale, + layout_config=self.__layout_config, + root_vertex=0, + ) + for mob in self.family_members_with_points(): + if (mob.get_center() == self._graph[edge[1]].get_center()).all(): + mob.points = mob.points.astype("float") + return self + + def insert_node3(self, node: Any, edge: tuple[Hashable, Hashable]): + """Inserts a node into the graph as (parent, node)""" + self.suspend_updating() + self.insert_node(node, edge) + # self.resume_updating() + self.insert_node2(node, edge) + + return self + + def remove_node(self, node: Hashable): + """Removes a node from the graph""" + self._graph.remove_vertices(node) + + # def insert_node2(self): + # """Shift by the given vectors. + # + # Parameters + # ---------- + # vectors + # Vectors to shift by. If multiple vectors are given, they are added + # together. + # + # Returns + # ------- + # :class:`Mobject` + # ``self`` + # + # See also + # -------- + # :meth:`move_to` + # """ + # + # total_vector = reduce(op.add, vectors) + # for mob in self.family_members_with_points(): + # mob.points = mob.points.astype("float") + # mob.points += total_vector + # + # return self + + +if __name__ == "__main__": + + class TestScene(Scene): + def construct(self): + # make a parent list for a tree + tree = Tree({0: 0, 1: 1, 2: 2, 3: 3}, [(0, 1), (0, 2), (1, 3)], Integer) + self.play(Create(tree)) + self.wait() + self.play(tree.animate.insert_node3(4, (2, 4)), run_time=0) + self.wait() + + config.preview = True + config.renderer = "cairo" + config.quality = "low_quality" + TestScene().render(preview=True) diff --git a/src/manim_data_structures/nary_tree.py b/src/manim_data_structures/nary_tree.py new file mode 100644 index 0000000..a4fc9c4 --- /dev/null +++ b/src/manim_data_structures/nary_tree.py @@ -0,0 +1,119 @@ +from typing import Any, Callable, Hashable + +import networkx as nx +import numpy as np +from m_tree import Tree +from manim import Mobject + + +def _nary_layout( + T: nx.classes.graph.Graph, + vertex_spacing: tuple | None = None, + n: int | None = None, +): + if not n: + raise ValueError("the n-ary tree layout requires the n parameter") + if not nx.is_tree(T): + raise ValueError("The tree layout must be used with trees") + + max_height = NaryTree.calc_loc(max(T), n)[1] + + def calc_pos(x, y): + """ + Scales the coordinates to the desired spacing + """ + return (x - (n**y - 1) / 2) * vertex_spacing[0] * n ** ( + max_height - y + ), y * vertex_spacing[1] + + return { + i: np.array([x, y, 0]) + for i, (x, y) in ((i, calc_pos(*NaryTree.calc_loc(i, n))) for i in T) + } + + +class NaryTree(Tree): + def __init__( + self, + nodes: dict[int, Any], + num_child: int, + vertex_type: Callable[..., Mobject], + edge_buff=0.4, + layout_config=None, + **kwargs + ): + if layout_config is None: + layout_config = {"vertex_spacing": (-1, 1)} + self.__layout_config = layout_config + self.num_child = num_child + + edges = [(self.get_parent(e), e) for e in nodes if e != 0] + super().__init__(nodes, edges, vertex_type, edge_buff, **kwargs) + dict_layout = _nary_layout(self._graph._graph, n=num_child, **layout_config) + self._graph.change_layout(dict_layout) + + @staticmethod + def calc_loc(i, n): + """ + Calculates the coordinates in terms of the shifted level order x position and level height + """ + if n == 1: + return 1, i + 1 + height = int(np.emath.logn(n, i * (n - 1) + 1)) + node_shift = (1 - n**height) // (1 - n) + return i - node_shift, height + + @staticmethod + def calc_idx(loc, n): + """ + Calculates the index from the coordinates + """ + x, y = loc + if n == 1: + return y - 1 + + return int(x + (1 - n**y) // (1 - n)) + + def get_parent(self, idx): + """ + Returns the index of the parent of the node at the given index + """ + x, y = NaryTree.calc_loc(idx, self.num_child) + new_loc = x // self.num_child, y - 1 + return NaryTree.calc_idx(new_loc, self.num_child) + + def insert_node(self, node: Any, index: Hashable): + """Inserts a node into the graph""" + res = super().insert_node(node, (self.get_parent(index), index)) + dict_layout = _nary_layout( + self._graph._graph, n=self.num_child, **self.__layout_config + ) + self._graph.change_layout(dict_layout) + self.update() + return res + + +if __name__ == "__main__": + from manim import * + + class TestScene(Scene): + def construct(self): + tree = NaryTree( + {0: 0, 1: 1, 4: 4}, + num_child=2, + vertex_type=Integer, + layout_config={"vertex_spacing": (1, -1)}, + ) + # tree._graph.change_layout(root_vertex=0, layout_config=tree._Tree__layout_config, + # layout_scale=tree._Tree__layout_scale) + self.play(Create(tree)) + self.wait() + tree.insert_node(3, 3) + self.wait() + tree.remove_node(4) + self.wait() + + config.preview = True + config.renderer = "cairo" + config.quality = "low_quality" + TestScene().render(preview=True)