Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion all_ds.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from ds.deque import Deque
from ds.binary_search_tree import BinarySearchTree
from ds.binary_indexed_tree import BinaryIndexedTree
from ds.segment_tree import SegmentTree
from ds.segment_tree import SegmentTree, SegmentTreeLight
from ds.disjoint_sets import DisjointSets
from ds.trie import Trie
from ds.cache import LRUCache, LFUCache
234 changes: 231 additions & 3 deletions ds/segment_tree.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# -*- coding: utf-8 -*-
"""
Segment tree data structure module, see help(SegmentTree) for details.
Segment tree data structure module, see help(SegmentTree) or help(SegmentTreeLight) for details.
"""

import math
Expand Down Expand Up @@ -40,7 +40,7 @@ class SegmentTree:
Segment trees also support making updates to elements in the array in O(log2(n)) time. This gives them
an advantage over prefix sums which require O(n) time to make updates to an arbitrary element in the
array. Segment trees also support the usage of monotonic, non-linear operators such as max and min which
prefix sumes do not.
prefix sums do not.

Segment trees are well suited to handle a stream of range queries and array updates. They are less well
suited to handle insertions or deletions to the array, which requires a rebuild of the segment tree.
Expand Down Expand Up @@ -292,7 +292,235 @@ def __str__(self) -> str:
"""
return self.__repr__()

def __len__(self):
def __len__(self) -> int:
"""
Returns the length of the internal array.
"""
return self.n


##########################
### Light-Weight Build ###
##########################

class SegmentTreeLight:
"""
Segment tree data-structure that uses an internal list to represent the data instead of node objects.
Using a python list instead makes this implementation a bit faster than the more extensive class above.

Segment trees allow for range queries to be executed on an array in O(log2(n)) time which is much faster
than O(n) which is required to traverse the array element by element to evaluate. A range query is some
function e.g. sum, max, min etc. applied over a given range of elements of an array e.g. sum(arr[4:8]).

Segment trees also support making updates to elements in the array in O(log2(n)) time. This gives them
an advantage over prefix sums which require O(n) time to make updates to an arbitrary element in the
array. Segment trees also support the usage of monotonic, non-linear operators such as max and min which
prefix sums do not.

Segment trees are well suited to handle a stream of range queries and array updates. They are less well
suited to handle insertions or deletions to the array, which requires a rebuild of the segment tree.
"""

def __init__(self, arr: List[Union[int, float]], eval_func: str):
"""
Constructor method for SegmentTreeLight data structure.

Parameters
----------
arr : List[Union[int, float]]
An input array of values upon which the segment tree is built.
eval_func : str
The function that will be evaluated over various ranges of arr in the segment tree.
Must be one of the following: ["min", "max", "sum", "gcd", "lcm"]. Note, gcd and lcm are only
appropriate if all values in arr are non-negative integers.

"""
self.func_dict = {"min": min, "max": max, "sum": lambda x, y: x + y}
self.default_vals = {"min": float("inf"), "max": -float("inf"), "sum": 0}
if eval_func not in self.func_dict:
raise ValueError(f"eval_func must be one of {list(self.func_dict.keys())}")
self.eval_func = eval_func # Record the eval function to be used in this segment tree
self.n = len(arr) # Keep track of the size of the segment tree
self.tree = [self.default_vals[eval_func]] * (4 * self.n) # Store the segment tree values as a list
if self.n > 0: # Construct the segment tree iff arr is non-empty
self._build_tree(arr, 0, 0, self.n - 1)

def _build_tree(self, arr: List[Union[int, float]], node: int, start: int, end: int) -> None:
"""
Internal helper method that creates the segment tree in self.tree by populating values in the array.

Parameters
----------
arr : List[Union[int, float]]
An input array of values upon which the segment tree is built.
node : int
The index of a given node in self.tree being edited which stores the eval_func computed over the
index range [start, end]. node = 0 accesses the root node.
start : int
The index denoting the start of the segment in arr that node represents e.g. if node == 0 then
start == 0 and this node represents all values in the array.
end : int
The index denoting the end of the segment in arr that node represents e.g. if node == 0 then
end == self.n - 1 and this node represents all values in the array.
"""
if start == end: # If the start and end index are the same, then this is a leaf node
self.tree[node] = arr[start]
return None

# Recursively build the left and right subtrees for this particular tree node
mid = (start + end) // 2 # Compute the middle index of the interval [start, end]
left_node = 2 * node + 1 # Get the left child node index
right_node = 2 * node + 2 # Get the right child node index

self._build_tree(arr, left_node, start, mid) # Recursively call to construct the tree
self._build_tree(arr, right_node, mid + 1, end) # Recursively call to construct the tree

# Internal nodes store the eval_func evaluated on the 2 child nodes
self.tree[node] = self.func_dict[self.eval_func](self.tree[left_node], self.tree[right_node])

def range_query(self, start: int, end: int) -> Union[float, int]:
"""
Performs a range query using the internal segment tree and returns the aggregate answer. Computes the
evaluation function over the interval [start, end] of the internal array. Computes and returns the
value of eval_func(arr[start, end + 1]) using the tree's pre-cached values.

Parameters
----------
start : int
The starting index of the range query.
end : int
The ending index of the range query.

Returns
-------
Union[float, int]
The eval_func applied to the original arr values starting at start and ending at end.

"""
assert start >= 0, "start must be >= 0"
assert end <= self.n - 1, "end must be <= n - 1"
assert start <= end, "start must be <= end"
# Pass params to internal helper method starting from the root node i.e. node == 0
return self._range_query(0, 0, self.n - 1, start, end)

def _range_query(self, node: int, start: int, end: int, start_q: int, end_q: int) -> Union[int, float]:
"""
Internal helper method used to evaluate range queries.

Parameters
----------
node : int
The index of a given node in self.tree being visited which stores the eval_func over the
index range [start, end] of elements in the original arr. node = 0 accesses the root node.
start : int
The index denoting the start of the segment in arr that node represents e.g. if node == 0 then
start == 0 and this node represents all values in the array.
end : int
The index denoting the end of the segment in arr that node represents e.g. if node == 0 then
end == self.n - 1 and this node represents all values in the array.
start_q : int
The starting index of the range query.
end_q : int
The ending index of the range query.

Returns
-------
Union[float, int]
The evaluation of the range query i.e. f(arr[start, end + 1]).

"""
# Case 1: This segment tree node's coverage in arr is entirely within the query range, return the
# value, we will need all of it i.e. all the elements [start, end] are within [start_q, end_q]
if start_q <= start and end <= end_q:
return self.tree[node]

# Case 2: This segment tree node's coverage in arr is outside the query range, return a default val
if end < start_q or start > end_q:
return self.default_vals[self.eval_func]

# Case 3: This segment tree node's coverage in arr is partially overlapping with the query range
mid = (start + end) // 2
left_node = 2 * node + 1
right_node = 2 * node + 2

# Bisect the current node's range and recursively evaluate
left_val = self._range_query(left_node, start, mid, start_q, end_q)
right_val = self._range_query(right_node, mid + 1, end, start_q, end_q)

return self.func_dict[self.eval_func](left_val, right_val)

def update(self, idx: int, val: Union[int, float]) -> None:
"""
Updates the segment tree for the following update to arr: arr[idx] = val

Parameters
----------
idx : int
The index in arr edited by the update operation.
val : Union[int, float]
The new value stored at arr[idx] after the update.

"""
assert 0 <= idx <= self.n - 1, "idx out of range for segment tree"
self._update(0, 0, self.n - 1, idx, val) # Pass params to internal helper function

def _update(self, node: int, start: int, end: int, idx: int, val: Union[int, float]) -> None:
"""
Internal helper method used to update the segment tree for the following update to arr: arr[idx] = val

Parameters
----------
node : int
The index of a given node in self.tree being visited which stores the eval_func over the
index range [start, end] of elements in the original arr. node = 0 accesses the root node.
start : int
The index denoting the start of the segment in arr that node represents e.g. if node == 0 then
start == 0 and this node represents all values in the array.
end : int
The index denoting the end of the segment in arr that node represents e.g. if node == 0 then
end == self.n - 1 and this node represents all values in the array.
idx : int
The index in arr edited by the update operation.
val : Union[int, float]
The new value stored at arr[idx] after the update.

"""
if start == end: # If the start and end index are the same, then this is a leaf node
self.tree[node] = val # Recursion base-case, update assigned value
return None

# Recursively update the appropriate subtree as a result of this new value
mid = (start + end) // 2
left_node = 2 * node + 1
right_node = 2 * node + 2

if idx <= mid: # If the update occurred in the left half, then update the left half
self._update(left_node, start, mid, idx, val)
else: # Otherwise if the update occurred in the right half, then update the right half
self._update(right_node, mid + 1, end, idx, val)

# Update the current node with the eval_func evaluated on the 2 child node values
self.tree[node] = self.func_dict[self.eval_func](self.tree[left_node], self.tree[right_node])

def __setitem__(self, idx: int, val: Union[int, float]) -> None:
"""
Support for obj[idx] = val changes to the underlying array and segmentation tree data structure.
"""
self.update(idx, val)

def __repr__(self) -> str:
"""
String representation of the object, reports the segment tree array and the operation specified.
"""
return str(self.tree) + " " + str(self.eval_func)

def __str__(self) -> str:
"""
String representation of the object, reports the segment tree array and the operation specified.
"""
return self.__repr__()

def __len__(self) -> int:
"""
Returns the length of the internal array.
"""
Expand Down
76 changes: 74 additions & 2 deletions tests/test_data_structures.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,8 @@
"""

from all_ds import BinarySearchTree, BinaryIndexedTree, Deque, DisjointSets, MinHeap, MaxHeap, LinkedList
from all_ds import DoublyLinkedList, SegmentTree, Trie, LRUCache, LFUCache
import pytest
from all_ds import DoublyLinkedList, SegmentTree, SegmentTreeLight, Trie, LRUCache, LFUCache
import pytest, math


def test_LinkedList():
Expand Down Expand Up @@ -289,10 +289,82 @@ def test_SegmentTree():
assert obj.pop() == test_data.pop()
assert obj.range_query(2, 4) == sum(test_data[2:5]), "Failed pop update test"

# Test the other possiable eval functions
err_msg = "Failed range_query test"
obj = SegmentTree(test_data, "max")
for start in range(len(test_data)): # Run tests on all possiable ranges
for end in range(start, len(test_data)):
assert max(test_data[start:end + 1]) == obj.range_query(start, end), err_msg

obj = SegmentTree(test_data, "min")
for start in range(len(test_data)): # Run tests on all possiable ranges
for end in range(start, len(test_data)):
assert min(test_data[start:end + 1]) == obj.range_query(start, end), err_msg

test_data2 = [1, 2, 3, 5, 8, 10, 12, 15]

obj = SegmentTree(test_data2, "gcd")
for start in range(len(test_data2)): # Run tests on all possiable ranges
for end in range(start, len(test_data2)):
assert math.gcd(*test_data2[start:end + 1]) == obj.range_query(start, end), err_msg

obj = SegmentTree(test_data2, "lcm")
for start in range(len(test_data2)): # Run tests on all possiable ranges
for end in range(start, len(test_data2)):
assert math.lcm(*test_data2[start:end + 1]) == obj.range_query(start, end), err_msg

with pytest.raises(ValueError): # Test selecting an unknown eval function
obj = SegmentTree(test_data, "xyz")


def test_SegmentTreeLight():
"""
Runs basic tests for the SegmentTreeLight data structure, tests methods and functionality.
"""
test_data = [1, 2, 3, 5, 8, -10, 12]
obj = SegmentTreeLight(test_data, "sum")
assert len(obj) == len(test_data), "Failed len(obj) test"
assert isinstance(str(obj), str), "Failed string representation test"

for start in range(len(test_data)): # Run tests on all possiable ranges
for end in range(start, len(test_data)):
assert sum(test_data[start:end + 1]) == obj.range_query(start, end), "Failed range_query test"

obj.update(2, 5)
test_data[2] = 5
assert obj.range_query(1, 3) == sum(test_data[1:4]), "Failed update test"

obj[2] = -5
test_data[2] = -5
assert obj.range_query(0, 4) == sum(test_data[0:5]), "Failed update test"

obj[2] = 25
test_data[2] = 25
assert obj.range_query(2, 3) == sum(test_data[2:4]), "Failed update test"

obj[2] = 25
assert obj.range_query(2, 3) == sum(test_data[2:4]), "Failed update test - no change update"

test_data = [1, 2, 3, 5, 8, -10, 12]
obj = SegmentTreeLight(test_data, "sum")
assert len(obj) == len(test_data), "Failed len(obj) test"
assert isinstance(str(obj), str), "Failed string representation test"

# Test the other 2 possiable eval functions
obj = SegmentTreeLight(test_data, "max")
for start in range(len(test_data)): # Run tests on all possiable ranges
for end in range(start, len(test_data)):
assert max(test_data[start:end + 1]) == obj.range_query(start, end), "Failed range_query test"

obj = SegmentTreeLight(test_data, "min")
for start in range(len(test_data)): # Run tests on all possiable ranges
for end in range(start, len(test_data)):
assert min(test_data[start:end + 1]) == obj.range_query(start, end), "Failed range_query test"

with pytest.raises(ValueError): # Test selecting an unknown eval function
obj = SegmentTreeLight(test_data, "xyz")


def test_BinarySearchTree():
"""
Runs basic tests for the BinarySearchTree data structure, tests methods and functionality.
Expand Down