Skip to content

Commit d071d83

Browse files
authored
Add Segment Tree data structure implementation
Implement Segment Tree for range queries and updates.
1 parent c79034c commit d071d83

File tree

1 file changed

+177
-0
lines changed

1 file changed

+177
-0
lines changed
Lines changed: 177 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,177 @@
1+
"""Segment Tree Data Structure.
2+
3+
A Segment Tree is a binary tree used for storing intervals or segments.
4+
It allows querying which of the stored segments contain a given point.
5+
Typically used for range queries and updates.
6+
7+
Time Complexity:
8+
- Build: O(n)
9+
- Query: O(log n)
10+
- Update: O(log n)
11+
12+
Space Complexity: O(n)
13+
"""
14+
15+
from typing import Callable
16+
17+
18+
class SegmentTree:
19+
"""Segment Tree implementation for range queries.
20+
21+
This implementation supports range sum queries and point updates.
22+
Can be extended to support other operations like min/max queries.
23+
24+
Attributes:
25+
tree: List storing the segment tree nodes
26+
n: Size of the input array
27+
operation: Function to combine two values (default: addition)
28+
29+
>>> st = SegmentTree([1, 3, 5, 7, 9, 11])
30+
>>> st.query(1, 3)
31+
15
32+
>>> st.update(1, 10)
33+
>>> st.query(1, 3)
34+
22
35+
>>> st.query(0, 5)
36+
42
37+
>>> st2 = SegmentTree([2, 4, 6, 8], operation=min)
38+
>>> st2.query(0, 3)
39+
2
40+
>>> st2.update(0, 10)
41+
>>> st2.query(0, 3)
42+
4
43+
"""
44+
45+
def __init__(
46+
self, arr: list[int], operation: Callable[[int, int], int] = lambda a, b: a + b
47+
) -> None:
48+
"""Initialize segment tree with given array.
49+
50+
Args:
51+
arr: Input array of integers
52+
operation: Binary operation to combine values (default: addition)
53+
54+
>>> st = SegmentTree([1, 2, 3])
55+
>>> len(st.tree)
56+
8
57+
"""
58+
self.n = len(arr)
59+
self.tree = [0] * (4 * self.n) # Allocate space for segment tree
60+
self.operation = operation
61+
self._build(arr, 0, 0, self.n - 1)
62+
63+
def _build(self, arr: list[int], node: int, start: int, end: int) -> None:
64+
"""Build segment tree recursively.
65+
66+
Args:
67+
arr: Input array
68+
node: Current node index in tree
69+
start: Start index of current segment
70+
end: End index of current segment
71+
"""
72+
if start == end:
73+
# Leaf node
74+
self.tree[node] = arr[start]
75+
else:
76+
mid = (start + end) // 2
77+
left_child = 2 * node + 1
78+
right_child = 2 * node + 2
79+
self._build(arr, left_child, start, mid)
80+
self._build(arr, right_child, mid + 1, end)
81+
self.tree[node] = self.operation(
82+
self.tree[left_child], self.tree[right_child]
83+
)
84+
85+
def query(self, left: int, right: int) -> int:
86+
"""Query for value in range [left, right].
87+
88+
Args:
89+
left: Left boundary of query range (inclusive)
90+
right: Right boundary of query range (inclusive)
91+
92+
Returns:
93+
Result of applying operation over the range
94+
95+
>>> st = SegmentTree([1, 2, 3, 4, 5])
96+
>>> st.query(0, 2)
97+
6
98+
>>> st.query(2, 4)
99+
12
100+
"""
101+
return self._query(0, 0, self.n - 1, left, right)
102+
103+
def _query(self, node: int, start: int, end: int, left: int, right: int) -> int:
104+
"""Recursive helper for range query.
105+
106+
Args:
107+
node: Current node index
108+
start: Start of current segment
109+
end: End of current segment
110+
left: Query left boundary
111+
right: Query right boundary
112+
113+
Returns:
114+
Query result for current segment
115+
"""
116+
if right < start or left > end:
117+
# No overlap
118+
return 0 if self.operation(0, 0) == 0 else float('inf')
119+
120+
if left <= start and end <= right:
121+
# Complete overlap
122+
return self.tree[node]
123+
124+
# Partial overlap
125+
mid = (start + end) // 2
126+
left_child = 2 * node + 1
127+
right_child = 2 * node + 2
128+
left_result = self._query(left_child, start, mid, left, right)
129+
right_result = self._query(right_child, mid + 1, end, left, right)
130+
return self.operation(left_result, right_result)
131+
132+
def update(self, index: int, value: int) -> None:
133+
"""Update value at given index.
134+
135+
Args:
136+
index: Index to update
137+
value: New value
138+
139+
>>> st = SegmentTree([1, 2, 3, 4, 5])
140+
>>> st.query(0, 4)
141+
15
142+
>>> st.update(2, 10)
143+
>>> st.query(0, 4)
144+
22
145+
"""
146+
self._update(0, 0, self.n - 1, index, value)
147+
148+
def _update(self, node: int, start: int, end: int, index: int, value: int) -> None:
149+
"""Recursive helper for point update.
150+
151+
Args:
152+
node: Current node index
153+
start: Start of current segment
154+
end: End of current segment
155+
index: Index to update
156+
value: New value
157+
"""
158+
if start == end:
159+
# Leaf node
160+
self.tree[node] = value
161+
else:
162+
mid = (start + end) // 2
163+
left_child = 2 * node + 1
164+
right_child = 2 * node + 2
165+
if index <= mid:
166+
self._update(left_child, start, mid, index, value)
167+
else:
168+
self._update(right_child, mid + 1, end, index, value)
169+
self.tree[node] = self.operation(
170+
self.tree[left_child], self.tree[right_child]
171+
)
172+
173+
174+
if __name__ == "__main__":
175+
import doctest
176+
177+
doctest.testmod()

0 commit comments

Comments
 (0)