|  | 
|  | 1 | +"""Segment Tree Data Structure. | 
|  | 2 | +
 | 
|  | 3 | +A Segment Tree is a binary tree used for storing intervals or segments. | 
|  | 4 | +It allows querying which of the stored segments contain a given point. | 
|  | 5 | +Typically used for range queries and updates. | 
|  | 6 | +
 | 
|  | 7 | +Time Complexity: | 
|  | 8 | +- Build: O(n) | 
|  | 9 | +- Query: O(log n) | 
|  | 10 | +- Update: O(log n) | 
|  | 11 | +
 | 
|  | 12 | +Space Complexity: O(n) | 
|  | 13 | +""" | 
|  | 14 | + | 
|  | 15 | +from typing import Callable | 
|  | 16 | + | 
|  | 17 | + | 
|  | 18 | +class SegmentTree: | 
|  | 19 | + """Segment Tree implementation for range queries. | 
|  | 20 | +
 | 
|  | 21 | + This implementation supports range sum queries and point updates. | 
|  | 22 | + Can be extended to support other operations like min/max queries. | 
|  | 23 | +
 | 
|  | 24 | + Attributes: | 
|  | 25 | + tree: List storing the segment tree nodes | 
|  | 26 | + n: Size of the input array | 
|  | 27 | + operation: Function to combine two values (default: addition) | 
|  | 28 | +
 | 
|  | 29 | + >>> st = SegmentTree([1, 3, 5, 7, 9, 11]) | 
|  | 30 | + >>> st.query(1, 3) | 
|  | 31 | + 15 | 
|  | 32 | + >>> st.update(1, 10) | 
|  | 33 | + >>> st.query(1, 3) | 
|  | 34 | + 22 | 
|  | 35 | + >>> st.query(0, 5) | 
|  | 36 | + 42 | 
|  | 37 | + >>> st2 = SegmentTree([2, 4, 6, 8], operation=min) | 
|  | 38 | + >>> st2.query(0, 3) | 
|  | 39 | + 2 | 
|  | 40 | + >>> st2.update(0, 10) | 
|  | 41 | + >>> st2.query(0, 3) | 
|  | 42 | + 4 | 
|  | 43 | + """ | 
|  | 44 | + | 
|  | 45 | + def __init__( | 
|  | 46 | + self, arr: list[int], operation: Callable[[int, int], int] = lambda a, b: a + b | 
|  | 47 | + ) -> None: | 
|  | 48 | + """Initialize segment tree with given array. | 
|  | 49 | +
 | 
|  | 50 | + Args: | 
|  | 51 | + arr: Input array of integers | 
|  | 52 | + operation: Binary operation to combine values (default: addition) | 
|  | 53 | +
 | 
|  | 54 | + >>> st = SegmentTree([1, 2, 3]) | 
|  | 55 | + >>> len(st.tree) | 
|  | 56 | + 8 | 
|  | 57 | + """ | 
|  | 58 | + self.n = len(arr) | 
|  | 59 | + self.tree = [0] * (4 * self.n) # Allocate space for segment tree | 
|  | 60 | + self.operation = operation | 
|  | 61 | + self._build(arr, 0, 0, self.n - 1) | 
|  | 62 | + | 
|  | 63 | + def _build(self, arr: list[int], node: int, start: int, end: int) -> None: | 
|  | 64 | + """Build segment tree recursively. | 
|  | 65 | +
 | 
|  | 66 | + Args: | 
|  | 67 | + arr: Input array | 
|  | 68 | + node: Current node index in tree | 
|  | 69 | + start: Start index of current segment | 
|  | 70 | + end: End index of current segment | 
|  | 71 | + """ | 
|  | 72 | + if start == end: | 
|  | 73 | + # Leaf node | 
|  | 74 | + self.tree[node] = arr[start] | 
|  | 75 | + else: | 
|  | 76 | + mid = (start + end) // 2 | 
|  | 77 | + left_child = 2 * node + 1 | 
|  | 78 | + right_child = 2 * node + 2 | 
|  | 79 | + self._build(arr, left_child, start, mid) | 
|  | 80 | + self._build(arr, right_child, mid + 1, end) | 
|  | 81 | + self.tree[node] = self.operation( | 
|  | 82 | + self.tree[left_child], self.tree[right_child] | 
|  | 83 | + ) | 
|  | 84 | + | 
|  | 85 | + def query(self, left: int, right: int) -> int: | 
|  | 86 | + """Query for value in range [left, right]. | 
|  | 87 | +
 | 
|  | 88 | + Args: | 
|  | 89 | + left: Left boundary of query range (inclusive) | 
|  | 90 | + right: Right boundary of query range (inclusive) | 
|  | 91 | +
 | 
|  | 92 | + Returns: | 
|  | 93 | + Result of applying operation over the range | 
|  | 94 | +
 | 
|  | 95 | + >>> st = SegmentTree([1, 2, 3, 4, 5]) | 
|  | 96 | + >>> st.query(0, 2) | 
|  | 97 | + 6 | 
|  | 98 | + >>> st.query(2, 4) | 
|  | 99 | + 12 | 
|  | 100 | + """ | 
|  | 101 | + return self._query(0, 0, self.n - 1, left, right) | 
|  | 102 | + | 
|  | 103 | + def _query(self, node: int, start: int, end: int, left: int, right: int) -> int: | 
|  | 104 | + """Recursive helper for range query. | 
|  | 105 | +
 | 
|  | 106 | + Args: | 
|  | 107 | + node: Current node index | 
|  | 108 | + start: Start of current segment | 
|  | 109 | + end: End of current segment | 
|  | 110 | + left: Query left boundary | 
|  | 111 | + right: Query right boundary | 
|  | 112 | +
 | 
|  | 113 | + Returns: | 
|  | 114 | + Query result for current segment | 
|  | 115 | + """ | 
|  | 116 | + if right < start or left > end: | 
|  | 117 | + # No overlap | 
|  | 118 | + return 0 if self.operation(0, 0) == 0 else float('inf') | 
|  | 119 | + | 
|  | 120 | + if left <= start and end <= right: | 
|  | 121 | + # Complete overlap | 
|  | 122 | + return self.tree[node] | 
|  | 123 | + | 
|  | 124 | + # Partial overlap | 
|  | 125 | + mid = (start + end) // 2 | 
|  | 126 | + left_child = 2 * node + 1 | 
|  | 127 | + right_child = 2 * node + 2 | 
|  | 128 | + left_result = self._query(left_child, start, mid, left, right) | 
|  | 129 | + right_result = self._query(right_child, mid + 1, end, left, right) | 
|  | 130 | + return self.operation(left_result, right_result) | 
|  | 131 | + | 
|  | 132 | + def update(self, index: int, value: int) -> None: | 
|  | 133 | + """Update value at given index. | 
|  | 134 | +
 | 
|  | 135 | + Args: | 
|  | 136 | + index: Index to update | 
|  | 137 | + value: New value | 
|  | 138 | +
 | 
|  | 139 | + >>> st = SegmentTree([1, 2, 3, 4, 5]) | 
|  | 140 | + >>> st.query(0, 4) | 
|  | 141 | + 15 | 
|  | 142 | + >>> st.update(2, 10) | 
|  | 143 | + >>> st.query(0, 4) | 
|  | 144 | + 22 | 
|  | 145 | + """ | 
|  | 146 | + self._update(0, 0, self.n - 1, index, value) | 
|  | 147 | + | 
|  | 148 | + def _update(self, node: int, start: int, end: int, index: int, value: int) -> None: | 
|  | 149 | + """Recursive helper for point update. | 
|  | 150 | +
 | 
|  | 151 | + Args: | 
|  | 152 | + node: Current node index | 
|  | 153 | + start: Start of current segment | 
|  | 154 | + end: End of current segment | 
|  | 155 | + index: Index to update | 
|  | 156 | + value: New value | 
|  | 157 | + """ | 
|  | 158 | + if start == end: | 
|  | 159 | + # Leaf node | 
|  | 160 | + self.tree[node] = value | 
|  | 161 | + else: | 
|  | 162 | + mid = (start + end) // 2 | 
|  | 163 | + left_child = 2 * node + 1 | 
|  | 164 | + right_child = 2 * node + 2 | 
|  | 165 | + if index <= mid: | 
|  | 166 | + self._update(left_child, start, mid, index, value) | 
|  | 167 | + else: | 
|  | 168 | + self._update(right_child, mid + 1, end, index, value) | 
|  | 169 | + self.tree[node] = self.operation( | 
|  | 170 | + self.tree[left_child], self.tree[right_child] | 
|  | 171 | + ) | 
|  | 172 | + | 
|  | 173 | + | 
|  | 174 | +if __name__ == "__main__": | 
|  | 175 | + import doctest | 
|  | 176 | + | 
|  | 177 | + doctest.testmod() | 
0 commit comments