Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
144 changes: 144 additions & 0 deletions src/manim_data_structures/bs_tree.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,144 @@
from typing import Any, Callable, Hashable

import nary_tree as nt
import networkx as nx
from manim import Mobject


class BSTree(nt.NaryTree):

# BSTree is a NaryTree with num_children=2.
# BSTree also keeps track of the values of the nodes in the tree with __values
def __init__(
self,
nodes: dict[int, Any],
vertex_type: Callable[..., Mobject],
edge_buff=0.4,
layout_config={"vertex_spacing": (0.6, -0.6)},
**kwargs
):
self.__values = nodes
super().__init__(nodes, 2, vertex_type, edge_buff, layout_config, **kwargs)

def insert_node(self, node: Any):
"""
Inserts a node into the BST, overriding definition in NaryTree
"""
# find the place to insert based on BST properties, starts at index=0
ind = self.find_insertion_index(node)
# take note of the new node's value
self.__values[ind] = node
# insert the new node into the correct position using NaryTree's implementation
return super().insert_node(node, ind)

def remove_node(self, node: Any):
"""
Removes a node with value 'node' from the BST
"""
ind = -1
# does not support removing the root
for k, v in self.__values.items():
if k != 0 and v == node:
ind = k
break
if ind == -1:
return
removal_ind = ind
# if left child exists
if ind * 2 + 1 in self.__values:
ind = ind * 2 + 1
# find rightmost
while ind * 2 + 2 in self.__values:
ind = ind * 2 + 2
# swap the rightmost down its left
while ind * 2 + 1 in self.__values:
self.__values[ind], self.__values[ind * 2 + 1] = (
self.__values[ind * 2 + 1],
self.__values[ind],
)
ind = ind * 2 + 1
# remove
self.__values[ind], self.__values[removal_ind] = (
self.__values[removal_ind],
self.__values[ind],
)
del self.__values[ind]
self._graph.remove_vertices(ind)
# get rid of graph
for vert in self.__values:
if vert != 0:
self._graph.remove_vertices(vert)
# rebuild; this also makes edges
for k, v in self.__values.items():
if k != 0:
super().insert_node(v, k)
# choose right child if left is not present
elif ind * 2 + 2 in self._graph.vertices:
ind = ind * 2 + 2
while ind * 2 + 1 in self.__values:
ind = ind * 2 + 1
while ind * 2 + 2 in self.__values:
self.__values[ind], self.__values[ind * 2 + 2] = (
self.__values[ind * 2 + 2],
self.__values[ind],
)
ind = ind * 2 + 2
self.__values[ind], self.__values[removal_ind] = (
self.__values[removal_ind],
self.__values[ind],
)
del self.__values[ind]
self._graph.remove_vertices(ind)

for vert in self.__values:
if vert != 0:
self._graph.remove_vertices(vert)
for k, v in self.__values.items():
if k != 0:
super().insert_node(v, k)
# just delete it if it has no children
else:
self._graph.remove_vertices(removal_ind)
del self.__values[removal_ind]

def find_insertion_index(self, node: Any, index=0):
"""
Finds the position where a value 'node' should be placed
"""
# if we reach an index not yet created, we are done
if index not in self._graph.vertices:
return index
# otherwise, try to find position at left or right based on value
if node <= self.__values[index]:
return self.find_insertion_index(node, index * 2 + 1)
else:
return self.find_insertion_index(node, index * 2 + 2)


if __name__ == "__main__":
from manim import *

class TestScene(Scene):
def construct(self):
tree = BSTree({0: 10}, vertex_type=Integer)
self.play(Create(tree))
tree.insert_node(5)
self.wait()
tree.insert_node(15)
self.wait()
tree.insert_node(20)
self.wait()
tree.insert_node(12)
self.wait()
tree.insert_node(14)
self.wait()
tree.insert_node(13)
self.wait()
tree.remove_node(15)
self.wait()
self.wait()

config.preview = True
config.renderer = "cairo"
config.quality = "high_quality"
TestScene().render()
144 changes: 144 additions & 0 deletions src/manim_data_structures/m_tree.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,144 @@
import operator as op
import random
from collections import defaultdict
from copy import copy
from functools import partialmethod, reduce
from typing import Any, Callable, Dict, Hashable, List, Tuple

import numpy as np
from manim import *
from manim import WHITE, Graph, Mobject, VMobject


class Tree(VMobject):
"""Computer Science Tree Data Structure"""

_graph: Graph
__layout_config: dict
__layout_scale: float
__layout: str | dict
__vertex_type: Callable[..., Mobject]

# __parents: list
# __children: dict[Hashable, list] = defaultdict(list)

def __init__(
self,
nodes: dict[int, Any],
edges: list[tuple[int, int]],
vertex_type: Callable[..., Mobject],
edge_buff=0.4,
layout="tree",
layout_config={"vertex_spacing": (-1, 1)},
root_vertex=0,
**kwargs
):
super().__init__(**kwargs)
vertex_mobjects = {k: vertex_type(v) for k, v in nodes.items()}
self.__layout_config = layout_config
self.__layout_scale = len(nodes) * 0.5
self.__layout = layout
self.__vertex_type = vertex_type
self._graph = Graph(
list(nodes),
edges,
vertex_mobjects=vertex_mobjects,
layout=layout,
root_vertex=0,
layout_config=self.__layout_config,
layout_scale=len(nodes) * 0.5,
edge_config={"stroke_width": 1, "stroke_color": WHITE},
)

def update_edges(graph: Graph):
"""Updates edges of graph"""
for (u, v), edge in graph.edges.items():
buff_vec = (
edge_buff
* (graph[u].get_center() - graph[v].get_center())
/ np.linalg.norm(graph[u].get_center() - graph[v].get_center())
)
edge.put_start_and_end_on(
graph[u].get_center() - buff_vec, graph[v].get_center() + buff_vec
)

self._graph.updaters.clear()
self._graph.updaters.append(update_edges)
self.add(self._graph)

def insert_node(self, node: Any, edge: tuple[Hashable, Hashable]):
"""Inserts a node into the graph as (parent, node)"""
self._graph.add_vertices(
edge[1], vertex_mobjects={edge[1]: self.__vertex_type(node)}
)
self._graph.add_edges(edge)
return self

def insert_node2(self, node: Any, edge: tuple[Hashable, Hashable]):
"""Inserts a node into the graph as (parent, node)"""
self._graph.change_layout(
self.__layout,
layout_scale=self.__layout_scale,
layout_config=self.__layout_config,
root_vertex=0,
)
for mob in self.family_members_with_points():
if (mob.get_center() == self._graph[edge[1]].get_center()).all():
mob.points = mob.points.astype("float")
return self

def insert_node3(self, node: Any, edge: tuple[Hashable, Hashable]):
"""Inserts a node into the graph as (parent, node)"""
self.suspend_updating()
self.insert_node(node, edge)
# self.resume_updating()
self.insert_node2(node, edge)

return self

def remove_node(self, node: Hashable):
"""Removes a node from the graph"""
self._graph.remove_vertices(node)

# def insert_node2(self):
# """Shift by the given vectors.
#
# Parameters
# ----------
# vectors
# Vectors to shift by. If multiple vectors are given, they are added
# together.
#
# Returns
# -------
# :class:`Mobject`
# ``self``
#
# See also
# --------
# :meth:`move_to`
# """
#
# total_vector = reduce(op.add, vectors)
# for mob in self.family_members_with_points():
# mob.points = mob.points.astype("float")
# mob.points += total_vector
#
# return self


if __name__ == "__main__":

class TestScene(Scene):
def construct(self):
# make a parent list for a tree
tree = Tree({0: 0, 1: 1, 2: 2, 3: 3}, [(0, 1), (0, 2), (1, 3)], Integer)
self.play(Create(tree))
self.wait()
self.play(tree.animate.insert_node3(4, (2, 4)), run_time=0)
self.wait()

config.preview = True
config.renderer = "cairo"
config.quality = "low_quality"
TestScene().render(preview=True)
Loading