From 4e0780d8f95c04aca24d81964578a18e015d2bf2 Mon Sep 17 00:00:00 2001 From: EH225 Date: Fri, 22 Aug 2025 13:25:03 -0400 Subject: [PATCH] add SegmentTreeLight --- all_ds.py | 2 +- ds/segment_tree.py | 234 +++++++++++++++++++++++++++++++++- tests/test_data_structures.py | 76 ++++++++++- 3 files changed, 306 insertions(+), 6 deletions(-) diff --git a/all_ds.py b/all_ds.py index 1775f19..1e9e027 100644 --- a/all_ds.py +++ b/all_ds.py @@ -8,7 +8,7 @@ from ds.deque import Deque from ds.binary_search_tree import BinarySearchTree from ds.binary_indexed_tree import BinaryIndexedTree -from ds.segment_tree import SegmentTree +from ds.segment_tree import SegmentTree, SegmentTreeLight from ds.disjoint_sets import DisjointSets from ds.trie import Trie from ds.cache import LRUCache, LFUCache diff --git a/ds/segment_tree.py b/ds/segment_tree.py index 8b1f475..4691d28 100644 --- a/ds/segment_tree.py +++ b/ds/segment_tree.py @@ -1,6 +1,6 @@ # -*- coding: utf-8 -*- """ -Segment tree data structure module, see help(SegmentTree) for details. +Segment tree data structure module, see help(SegmentTree) or help(SegmentTreeLight) for details. """ import math @@ -40,7 +40,7 @@ class SegmentTree: Segment trees also support making updates to elements in the array in O(log2(n)) time. This gives them an advantage over prefix sums which require O(n) time to make updates to an arbitrary element in the array. Segment trees also support the usage of monotonic, non-linear operators such as max and min which - prefix sumes do not. + prefix sums do not. Segment trees are well suited to handle a stream of range queries and array updates. They are less well suited to handle insertions or deletions to the array, which requires a rebuild of the segment tree. @@ -292,7 +292,235 @@ def __str__(self) -> str: """ return self.__repr__() - def __len__(self): + def __len__(self) -> int: + """ + Returns the length of the internal array. 
+ """ + return self.n + + +########################## +### Light-Weight Build ### +########################## + +class SegmentTreeLight: + """ + Segment tree data-structure that uses an internal list to represent the data instead of node objects. + Using a python list instead makes this implementation a bit faster than the more extensive class above. + + Segment trees allow for range queries to be executed on an array in O(log2(n)) time which is much faster + than O(n) which is required to traverse the array element by element to evaluate. A range query is some + function e.g. sum, max, min etc. applied over a given range of elements of an array e.g. sum(arr[4:8]). + + Segment trees also support making updates to elements in the array in O(log2(n)) time. This gives them + an advantage over prefix sums which require O(n) time to make updates to an arbitrary element in the + array. Segment trees also support the usage of monotonic, non-linear operators such as max and min which + prefix sums do not. + + Segment trees are well suited to handle a stream of range queries and array updates. They are less well + suited to handle insertions or deletions to the array, which requires a rebuild of the segment tree. + """ + + def __init__(self, arr: List[Union[int, float]], eval_func: str): + """ + Constructor method for SegmentTreeLight data structure. + + Parameters + ---------- + arr : List[Union[int, float]] + An input array of values upon which the segment tree is built. + eval_func : str + The function that will be evaluated over various ranges of arr in the segment tree. + Must be one of the following: ["min", "max", "sum"]. Unlike SegmentTree, this light-weight + implementation does not support the gcd or lcm operators.
+ + """ + self.func_dict = {"min": min, "max": max, "sum": lambda x, y: x + y} + self.default_vals = {"min": float("inf"), "max": -float("inf"), "sum": 0} + if eval_func not in self.func_dict: + raise ValueError(f"eval_func must be one of {list(self.func_dict.keys())}") + self.eval_func = eval_func # Record the eval function to be used in this segment tree + self.n = len(arr) # Keep track of the size of the segment tree + self.tree = [self.default_vals[eval_func]] * (4 * self.n) # Store the segment tree values as a list + if self.n > 0: # Construct the segment tree iff arr is non-empty + self._build_tree(arr, 0, 0, self.n - 1) + + def _build_tree(self, arr: List[Union[int, float]], node: int, start: int, end: int) -> None: + """ + Internal helper method that creates the segment tree in self.tree by populating values in the array. + + Parameters + ---------- + arr : List[Union[int, float]] + An input array of values upon which the segment tree is built. + node : int + The index of a given node in self.tree being edited which stores the eval_func computed over the + index range [start, end]. node = 0 accesses the root node. + start : int + The index denoting the start of the segment in arr that node represents e.g. if node == 0 then + start == 0 and this node represents all values in the array. + end : int + The index denoting the end of the segment in arr that node represents e.g. if node == 0 then + end == self.n - 1 and this node represents all values in the array. 
+ """ + if start == end: # If the start and end index are the same, then this is a leaf node + self.tree[node] = arr[start] + return None + + # Recursively build the left and right subtrees for this particular tree node + mid = (start + end) // 2 # Compute the middle index of the interval [start, end] + left_node = 2 * node + 1 # Get the left child node index + right_node = 2 * node + 2 # Get the right child node index + + self._build_tree(arr, left_node, start, mid) # Recursively call to construct the tree + self._build_tree(arr, right_node, mid + 1, end) # Recursively call to construct the tree + + # Internal nodes store the eval_func evaluated on the 2 child nodes + self.tree[node] = self.func_dict[self.eval_func](self.tree[left_node], self.tree[right_node]) + + def range_query(self, start: int, end: int) -> Union[float, int]: + """ + Performs a range query using the internal segment tree and returns the aggregate answer. Computes the + evaluation function over the interval [start, end] of the internal array. Computes and returns the + value of eval_func(arr[start:end + 1]) using the tree's pre-cached values. + + Parameters + ---------- + start : int + The starting index of the range query. + end : int + The ending index of the range query. + + Returns + ------- + Union[float, int] + The eval_func applied to the original arr values starting at start and ending at end. + + """ + assert start >= 0, "start must be >= 0" + assert end <= self.n - 1, "end must be <= n - 1" + assert start <= end, "start must be <= end" + # Pass params to internal helper method starting from the root node i.e. node == 0 + return self._range_query(0, 0, self.n - 1, start, end) + + def _range_query(self, node: int, start: int, end: int, start_q: int, end_q: int) -> Union[int, float]: + """ + Internal helper method used to evaluate range queries.
+ + Parameters + ---------- + node : int + The index of a given node in self.tree being visited which stores the eval_func over the + index range [start, end] of elements in the original arr. node = 0 accesses the root node. + start : int + The index denoting the start of the segment in arr that node represents e.g. if node == 0 then + start == 0 and this node represents all values in the array. + end : int + The index denoting the end of the segment in arr that node represents e.g. if node == 0 then + end == self.n - 1 and this node represents all values in the array. + start_q : int + The starting index of the range query. + end_q : int + The ending index of the range query. + + Returns + ------- + Union[float, int] + The eval_func evaluated over the overlap of [start, end] and [start_q, end_q]. + + """ + # Case 1: This segment tree node's coverage in arr is entirely within the query range, return the + # value, we will need all of it i.e. all the elements [start, end] are within [start_q, end_q] + if start_q <= start and end <= end_q: + return self.tree[node] + + # Case 2: This segment tree node's coverage in arr is outside the query range, return a default val + if end < start_q or start > end_q: + return self.default_vals[self.eval_func] + + # Case 3: This segment tree node's coverage in arr is partially overlapping with the query range + mid = (start + end) // 2 + left_node = 2 * node + 1 + right_node = 2 * node + 2 + + # Bisect the current node's range and recursively evaluate + left_val = self._range_query(left_node, start, mid, start_q, end_q) + right_val = self._range_query(right_node, mid + 1, end, start_q, end_q) + + return self.func_dict[self.eval_func](left_val, right_val) + + def update(self, idx: int, val: Union[int, float]) -> None: + """ + Updates the segment tree for the following update to arr: arr[idx] = val + + Parameters + ---------- + idx : int + The index in arr edited by the update operation.
+ val : Union[int, float] + The new value stored at arr[idx] after the update. + + """ + assert 0 <= idx <= self.n - 1, "idx out of range for segment tree" + self._update(0, 0, self.n - 1, idx, val) # Pass params to internal helper function + + def _update(self, node: int, start: int, end: int, idx: int, val: Union[int, float]) -> None: + """ + Internal helper method used to update the segment tree for the following update to arr: arr[idx] = val + + Parameters + ---------- + node : int + The index of a given node in self.tree being visited which stores the eval_func over the + index range [start, end] of elements in the original arr. node = 0 accesses the root node. + start : int + The index denoting the start of the segment in arr that node represents e.g. if node == 0 then + start == 0 and this node represents all values in the array. + end : int + The index denoting the end of the segment in arr that node represents e.g. if node == 0 then + end == self.n - 1 and this node represents all values in the array. + idx : int + The index in arr edited by the update operation. + val : Union[int, float] + The new value stored at arr[idx] after the update. 
+ + """ + if start == end: # If the start and end index are the same, then this is a leaf node + self.tree[node] = val # Recursion base-case, update assigned value + return None + + # Recursively update the appropriate subtree as a result of this new value + mid = (start + end) // 2 + left_node = 2 * node + 1 + right_node = 2 * node + 2 + + if idx <= mid: # If the update occurred in the left half, then update the left half + self._update(left_node, start, mid, idx, val) + else: # Otherwise if the update occurred in the right half, then update the right half + self._update(right_node, mid + 1, end, idx, val) + + # Update the current node with the eval_func evaluated on the 2 child node values + self.tree[node] = self.func_dict[self.eval_func](self.tree[left_node], self.tree[right_node]) + + def __setitem__(self, idx: int, val: Union[int, float]) -> None: + """ + Support for obj[idx] = val changes to the underlying array and segment tree data structure. + """ + self.update(idx, val) + + def __repr__(self) -> str: + """ + String representation of the object, reports the segment tree array and the operation specified. + """ + return str(self.tree) + " " + str(self.eval_func) + + def __str__(self) -> str: + """ + String representation of the object, reports the segment tree array and the operation specified. + """ + return self.__repr__() + + def __len__(self) -> int: """ Returns the length of the internal array.
""" diff --git a/tests/test_data_structures.py b/tests/test_data_structures.py index 5e75eb9..17301d7 100644 --- a/tests/test_data_structures.py +++ b/tests/test_data_structures.py @@ -4,8 +4,8 @@ """ from all_ds import BinarySearchTree, BinaryIndexedTree, Deque, DisjointSets, MinHeap, MaxHeap, LinkedList -from all_ds import DoublyLinkedList, SegmentTree, Trie, LRUCache, LFUCache -import pytest +from all_ds import DoublyLinkedList, SegmentTree, SegmentTreeLight, Trie, LRUCache, LFUCache +import pytest, math def test_LinkedList(): @@ -289,10 +289,82 @@ def test_SegmentTree(): assert obj.pop() == test_data.pop() assert obj.range_query(2, 4) == sum(test_data[2:5]), "Failed pop update test" + # Test the other possiable eval functions + err_msg = "Failed range_query test" + obj = SegmentTree(test_data, "max") + for start in range(len(test_data)): # Run tests on all possiable ranges + for end in range(start, len(test_data)): + assert max(test_data[start:end + 1]) == obj.range_query(start, end), err_msg + + obj = SegmentTree(test_data, "min") + for start in range(len(test_data)): # Run tests on all possiable ranges + for end in range(start, len(test_data)): + assert min(test_data[start:end + 1]) == obj.range_query(start, end), err_msg + + test_data2 = [1, 2, 3, 5, 8, 10, 12, 15] + + obj = SegmentTree(test_data2, "gcd") + for start in range(len(test_data2)): # Run tests on all possiable ranges + for end in range(start, len(test_data2)): + assert math.gcd(*test_data2[start:end + 1]) == obj.range_query(start, end), err_msg + + obj = SegmentTree(test_data2, "lcm") + for start in range(len(test_data2)): # Run tests on all possiable ranges + for end in range(start, len(test_data2)): + assert math.lcm(*test_data2[start:end + 1]) == obj.range_query(start, end), err_msg + with pytest.raises(ValueError): # Test selecting an unknown eval function obj = SegmentTree(test_data, "xyz") +def test_SegmentTreeLight(): + """ + Runs basic tests for the SegmentTreeLight data structure, tests 
+ methods and functionality. + """ + test_data = [1, 2, 3, 5, 8, -10, 12] + obj = SegmentTreeLight(test_data, "sum") + assert len(obj) == len(test_data), "Failed len(obj) test" + assert isinstance(str(obj), str), "Failed string representation test" + + for start in range(len(test_data)): # Run tests on all possible ranges + for end in range(start, len(test_data)): + assert sum(test_data[start:end + 1]) == obj.range_query(start, end), "Failed range_query test" + + obj.update(2, 5) + test_data[2] = 5 + assert obj.range_query(1, 3) == sum(test_data[1:4]), "Failed update test" + + obj[2] = -5 + test_data[2] = -5 + assert obj.range_query(0, 4) == sum(test_data[0:5]), "Failed update test" + + obj[2] = 25 + test_data[2] = 25 + assert obj.range_query(2, 3) == sum(test_data[2:4]), "Failed update test" + + obj[2] = 25 + assert obj.range_query(2, 3) == sum(test_data[2:4]), "Failed update test - no change update" + + test_data = [1, 2, 3, 5, 8, -10, 12] + obj = SegmentTreeLight(test_data, "sum") + assert len(obj) == len(test_data), "Failed len(obj) test" + assert isinstance(str(obj), str), "Failed string representation test" + + # Test the other 2 possible eval functions + obj = SegmentTreeLight(test_data, "max") + for start in range(len(test_data)): # Run tests on all possible ranges + for end in range(start, len(test_data)): + assert max(test_data[start:end + 1]) == obj.range_query(start, end), "Failed range_query test" + + obj = SegmentTreeLight(test_data, "min") + for start in range(len(test_data)): # Run tests on all possible ranges + for end in range(start, len(test_data)): + assert min(test_data[start:end + 1]) == obj.range_query(start, end), "Failed range_query test" + + with pytest.raises(ValueError): # Test selecting an unknown eval function + obj = SegmentTreeLight(test_data, "xyz") + + def test_BinarySearchTree(): """ Runs basic tests for the BinarySearchTree data structure, tests methods and functionality.