Segment Tree
This is a practical guide to implementing a segment tree and developing the intuition to apply it to algorithmic problems. A segment tree is:
- A binary tree that represents a list of elements
- The leaves of this tree (or in some cases the lowest level of the tree) hold the elements of this list
- Every internal node contains some information aggregated over the node's children
- For example sum, average, min or max
For this discussion I'll use a list of numbers as an example, [3,8,6,4,2,5,9,0,7,1],
and the information each node stores is the smallest number among the node's children, the "minimum".
Creating the Tree
Usually the segment tree is stored as an array. Being a binary tree with the root at index 0, the ith node has its children at 2*i+1 and 2*i+2.
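The index arithmetic can be checked with a small sketch. The helper names `children` and `parent` are hypothetical and only serve this illustration; they assume the root sits at index 0.

```python
def children(i):
    """Return the array indices of the left and right child of node i."""
    return 2 * i + 1, 2 * i + 2


def parent(i):
    """Return the array index of the parent of node i (i must be > 0)."""
    return (i - 1) // 2


print(children(0))           # (1, 2)
print(children(2))           # (5, 6)
print(parent(5), parent(6))  # 2 2
```

Both children map back to the same parent, which is what lets the tree live in a flat array without explicit pointers.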
The first step in creating the tree is allocating an array big enough to store all potential nodes. The input array becomes the leaves of this tree, so the number of nodes required is easy to calculate from a few basic properties of a binary tree.
- Calculate the length of the input sequence [3,8,6,4,2,5,9,0,7,1]
arr = [3,8,6,4,2,5,9,0,7,1]
N = len(arr)
- If the size of the array is not a power of 2, pad the array with placeholders to make it so. This makes the segment tree node calculation easier and keeps the final tree balanced as well.
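import math
# smallest k such that 2^k >= N
k = math.ceil(math.log2(N))
while len(arr) < pow(2, k):
    # math.inf never wins a min() comparison, so it is a safe placeholder
    arr.append(math.inf)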
- Now that we know the number of leaf elements, 2^k, the total number of nodes is 2^(k+1) - 1, so an array of size 2^(k+1) is big enough
node_count = pow(2, k+1)
- Initialize the array with placeholders; this is our segment tree without any elements
segment_tree = [math.inf]*node_count
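The allocation steps above can be collected into one function. This is a minimal sketch, not the only way to do it: the name `allocate` is hypothetical, the input is assumed to be non-empty, and the exponent is computed with the integer method `int.bit_length` instead of a floating-point logarithm, which sidesteps possible rounding issues for large inputs.

```python
import math


def allocate(arr):
    """Pad a copy of arr to a power-of-two length and allocate the tree array.

    Assumes arr is non-empty. Returns (padded_list, empty_tree).
    """
    n = len(arr)
    # smallest k with 2**k >= n, using integer arithmetic only
    k = (n - 1).bit_length()
    padded = list(arr) + [math.inf] * (2 ** k - n)
    # 2**(k+1) slots hold the 2**(k+1) - 1 nodes of a perfect binary tree
    tree = [math.inf] * (2 ** (k + 1))
    return padded, tree


padded, tree = allocate([3, 8, 6, 4, 2, 5, 9, 0, 7, 1])
print(len(padded), len(tree))  # 16 32
```

Working on a copy leaves the caller's list untouched, which may or may not be what you want; appending in place, as the steps above do, is equally valid.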
Building the Tree
- The tree is built recursively
- Each recursive call works on a range of input array elements
- Every call populates one index in the segment tree and returns its value
- If the range denotes only one element, i.e. a leaf of the segment tree, then the input array value is put there
- Otherwise the call partitions the range in two and merges the results from the left and right subtrees
def build_tree(tree, arr, tree_index, lo, hi):
    if (lo == hi):
        # leaf node, save the array element
        # this is the smallest number
        tree[tree_index] = arr[lo]
        return tree[tree_index]
    left_inx = 2*tree_index+1
    right_inx = 2*tree_index+2
    mid = (lo + hi) // 2
    left = build_tree(tree, arr, left_inx, lo, mid)
    right = build_tree(tree, arr, right_inx, mid+1, hi)
    # merge the result
    tree[tree_index] = min(left, right)
    return tree[tree_index]

build_tree(segment_tree, arr, 0, 0, len(arr)-1)
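Because the padded array has a power-of-two length, the leaves end up in order in one contiguous block of the tree array, so the same tree can also be filled bottom-up with a loop. The sketch below is an alternative to the recursive build, not a replacement for it; `build_bottom_up` is a hypothetical name and the function assumes its input is already padded.

```python
import math


def build_bottom_up(padded):
    """Build a min segment tree from a list whose length is a power of two."""
    size = len(padded)
    tree = [math.inf] * (2 * size)
    # the leaves occupy indices size-1 .. 2*size-2, in input order
    tree[size - 1:2 * size - 1] = padded
    # fill internal nodes from the last one up to the root
    for i in range(size - 2, -1, -1):
        tree[i] = min(tree[2 * i + 1], tree[2 * i + 2])
    return tree


padded = [3, 8, 6, 4, 2, 5, 9, 0, 7, 1] + [math.inf] * 6
tree = build_bottom_up(padded)
print(tree[0])    # 0
print(tree[1:3])  # [0, 1]
print(tree[3:7])  # [3, 0, 1, inf]
```

Printing the tree level by level like this is also a handy way to inspect what each node stores: the root holds the overall minimum, and each level below splits the range in half.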
Range Query
Given the current example, the queries ask for the minimum number in a range. The answer is found by traversing the segment tree recursively, locating the nodes that cover the range in question, and reading the result from them. The recursive routine takes the query range and the current segment tree node, together with the range of the input array that the node represents. The following cases arise.
- If the query range is disjoint from the current range, return the placeholder.
- If the current range lies completely within the query range, return the value in the segment tree node.
- If the query range is completely in the left half of the current range, recursively query the left subtree.
- If the query range is completely in the right half of the current range, recursively query the right subtree.
- If the query range spans both halves, query both subtrees and return the minimum.
def query(tree, tree_index, lo, hi, i, j):
    # print(tree_index, lo, hi, i, j)
    if (lo > j or hi < i): # disjoint
        return math.inf
    if (i <= lo and j >= hi): # completely inside
        # print(tree_index, lo, hi, i, j)
        return tree[tree_index]
    left_inx = 2*tree_index+1
    right_inx = 2*tree_index+2
    mid = (lo + hi) // 2
    if i >= mid+1:
        # completely in right half
        return query(tree, right_inx, mid+1, hi, i, j)
    elif j <= mid:
        # completely in left half
        return query(tree, left_inx, lo, mid, i, j)
    # in case of overlap
    left = query(tree, left_inx, lo, mid, i, j)
    right = query(tree, right_inx, mid+1, hi, i, j)
    return min(left, right)

print("[0, n-1]", query(segment_tree, 0, 0, len(arr)-1, 0, len(arr)-1))
print("[0, 1]", query(segment_tree, 0, 0, len(arr)-1, 0, 1))
print("[2, 6]", query(segment_tree, 0, 0, len(arr)-1, 2, 6))
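A cheap way to gain confidence in a range query is to compare it against a brute-force `min` over every possible range. The sketch below is self-contained and uses compact, hypothetical helpers (`build`, `range_min`, `check_all`) rather than the exact functions above; `range_min` keeps only the disjoint, fully covered, and split cases, since the left-only and right-only cases fall out of the disjoint check.

```python
import math


def build(tree, arr, node, lo, hi):
    """Fill tree[node] with the minimum of arr[lo..hi] and return it."""
    if lo == hi:
        tree[node] = arr[lo]
    else:
        mid = (lo + hi) // 2
        tree[node] = min(build(tree, arr, 2 * node + 1, lo, mid),
                         build(tree, arr, 2 * node + 2, mid + 1, hi))
    return tree[node]


def range_min(tree, node, lo, hi, i, j):
    """Minimum of arr[i..j], where node covers arr[lo..hi]."""
    if lo > j or hi < i:     # disjoint
        return math.inf
    if i <= lo and hi <= j:  # node range fully inside the query range
        return tree[node]
    mid = (lo + hi) // 2
    return min(range_min(tree, 2 * node + 1, lo, mid, i, j),
               range_min(tree, 2 * node + 2, mid + 1, hi, i, j))


def check_all(tree, padded):
    """Compare every range query against a brute-force min()."""
    n = len(padded)
    return all(range_min(tree, 0, 0, n - 1, i, j) == min(padded[i:j + 1])
               for i in range(n) for j in range(i, n))


padded = [3, 8, 6, 4, 2, 5, 9, 0, 7, 1] + [math.inf] * 6  # length 16
tree = [math.inf] * (2 * len(padded))
build(tree, padded, 0, 0, len(padded) - 1)

print(check_all(tree, padded))                        # True
print(range_min(tree, 0, 0, len(padded) - 1, 0, 1))   # 3
print(range_min(tree, 0, 0, len(padded) - 1, 2, 6))   # 2
```

The brute-force check is quadratic in the number of ranges, so it is only meant for small test inputs.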
Updating the Tree
When the input array changes, the segment tree has to be updated as well. This is also done recursively: the recursion descends to the affected leaf, and every node on the path back up to the root is recomputed from its two children.
def update(tree, tree_index, lo, hi, arr_index, val):
    if (lo == hi):
        # Update the leaf
        tree[tree_index] = val
        return val
    left_inx = 2*tree_index+1
    right_inx = 2*tree_index+2
    mid = (lo + hi) // 2
    if arr_index >= mid+1:
        # in right half
        update(tree, right_inx, mid+1, hi, arr_index, val)
    else:
        # in left half
        update(tree, left_inx, lo, mid, arr_index, val)
    # recompute from both children so the node stays correct
    # even when the new value is larger than the old one
    tree[tree_index] = min(tree[left_inx], tree[right_inx])
    return tree[tree_index]

arr[5] = 1
update(segment_tree, 0, 0, len(arr)-1, 5, 1)
print("[2, 6]", query(segment_tree, 0, 0, len(arr)-1, 2, 6))
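One case worth testing is an update that raises a value. The old minimum stored in an ancestor may then no longer be valid, so each ancestor has to be recomputed from its two children rather than compared with its previous value. The following self-contained sketch, with hypothetical helper names `build` and `point_update`, replaces the 0 at index 7 with 10 and checks that the root changes accordingly.

```python
import math


def build(tree, arr, node, lo, hi):
    """Fill tree[node] with the minimum of arr[lo..hi] and return it."""
    if lo == hi:
        tree[node] = arr[lo]
    else:
        mid = (lo + hi) // 2
        tree[node] = min(build(tree, arr, 2 * node + 1, lo, mid),
                         build(tree, arr, 2 * node + 2, mid + 1, hi))
    return tree[node]


def point_update(tree, node, lo, hi, idx, val):
    """Set position idx to val and repair every node on the path to the root."""
    if lo == hi:
        tree[node] = val
        return
    mid = (lo + hi) // 2
    left, right = 2 * node + 1, 2 * node + 2
    if idx <= mid:
        point_update(tree, left, lo, mid, idx, val)
    else:
        point_update(tree, right, mid + 1, hi, idx, val)
    # recompute from the children; the old value of this node is ignored
    tree[node] = min(tree[left], tree[right])


padded = [3, 8, 6, 4, 2, 5, 9, 0, 7, 1] + [math.inf] * 6  # length 16
tree = [math.inf] * (2 * len(padded))
build(tree, padded, 0, 0, len(padded) - 1)

before = tree[0]  # minimum of the whole array
padded[7] = 10    # the old minimum, 0, is replaced by a larger value
point_update(tree, 0, 0, len(padded) - 1, 7, 10)
after = tree[0]
print(before, after)  # 0 1
```

Had the ancestors been set to the minimum of their old value and the updated child, the root would have stayed at 0 even though 0 no longer appears in the array.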