1"""
2AVL Tree - Self-Balancing Binary Search Tree
3
4An AVL tree maintains balance by ensuring that for every node, the height
5difference between left and right subtrees (balance factor) is in [-1, 1].
6Rebalancing is performed via rotations after insertions and deletions.
7
8Time Complexity (where h = O(log n) due to balancing):
9 - add: O(log n)
10 - search: O(log n)
11 - delete: O(log n)
12 - find_successor: O(log n)
13 - traverse: O(n)
14
15Space Complexity:
16 - Storage: O(n)
17 - Operations: O(log n) recursion stack
18"""
19
20from enum import Enum
21
22
class AvlNode:
    """Single AVL tree node: a stored value, two child links and a cached height.

    A freshly created node is a leaf, so its height starts at 1 (an empty
    subtree is treated as height 0 by the tree's helpers).
    """

    __slots__ = ("val", "left", "right", "height")

    def __init__(self, val):
        self.val = val
        # New nodes are leaves: no children yet, height 1.
        self.left = self.right = None
        self.height = 1

    def __str__(self):
        """Render the node as its stored value."""
        return str(self.val)

    # The debug representation is intentionally the same as str().
    __repr__ = __str__
39
40
class AvlTree:
    """
    Self-balancing AVL tree implementation.

    Maintains the AVL invariant: |balance_factor| <= 1 for all nodes.
    Provides operations for adding, searching, deleting nodes,
    finding successors, and traversing the tree.

    Stored values must be comparable with ``<``, ``>`` and ``==``;
    duplicate values are rejected by :meth:`add`.
    """

    def __init__(self):
        """Initialize an empty AVL tree."""
        self.root = None

    def __contains__(self, val):
        """Allow 'val in tree' syntax."""
        return self.search(val) is not None

    # -------------------------------------------------------------------------
    # Height and Balance Factor Helpers
    # -------------------------------------------------------------------------

    def _height(self, node: AvlNode | None) -> int:
        """Return the cached height of node (0 for None, 1 for a leaf)."""
        return node.height if node else 0

    def _update_height(self, node: AvlNode) -> None:
        """Recompute node's height from its children's cached heights."""
        node.height = 1 + max(self._height(node.left), self._height(node.right))

    def _balance_factor(self, node: AvlNode | None) -> int:
        """Return balance factor: left_height - right_height (0 for None)."""
        return self._height(node.left) - self._height(node.right) if node else 0

    # -------------------------------------------------------------------------
    # AVL Rotations
    # -------------------------------------------------------------------------
    # Rotations preserve BST ordering while rebalancing the tree.
    #
    #   Left Rotation:              Right Rotation:
    #       x             y               y             x
    #      / \    =>     / \             / \    =>     / \
    #     A   y         x   C           x   C         A   y
    #        / \       / \             / \               / \
    #       B   C     A   B           A   B             B   C
    #
    # CRITICAL: Height updates must be bottom-up (child before parent)
    # to avoid stale height values causing incorrect balance factors.
    # -------------------------------------------------------------------------

    def _rotate_left(self, x: AvlNode) -> AvlNode:
        """Perform left rotation on node x, return new subtree root.

        Returns x unchanged if x is None or has no right child.
        """
        if not x or not x.right:
            return x

        y = x.right
        x.right = y.left
        y.left = x

        self._update_height(x)  # Child first (now lower in tree)
        self._update_height(y)  # Then parent

        return y

    def _rotate_right(self, y: AvlNode) -> AvlNode:
        """Perform right rotation on node y, return new subtree root.

        Returns y unchanged if y is None or has no left child.
        """
        if not y or not y.left:
            return y

        x = y.left
        y.left = x.right
        x.right = y

        self._update_height(y)  # Child first (now lower in tree)
        self._update_height(x)  # Then parent

        return x

    def _rebalance(self, node: AvlNode) -> AvlNode:
        """
        Refresh node's height, rotate if it is unbalanced, and return the
        (potentially new) root of this subtree.

        Called on the unwind path of add/delete, after node's children have
        themselves been rebalanced and had their heights refreshed.
        """
        self._update_height(node)
        balance = self._balance_factor(node)

        # Left-heavy
        if balance > 1:
            if self._balance_factor(node.left) < 0:
                # Left-Right case: rotate left child left, then node right
                node.left = self._rotate_left(node.left)
            return self._rotate_right(node)

        # Right-heavy
        if balance < -1:
            if self._balance_factor(node.right) > 0:
                # Right-Left case: rotate right child right, then node left
                node.right = self._rotate_right(node.right)
            return self._rotate_left(node)

        return node

    # -------------------------------------------------------------------------
    # Public Operations
    # -------------------------------------------------------------------------

    def add(self, val) -> None:
        """
        Add a value to the tree, rebalancing along the insertion path.

        Raises:
            KeyError: if ``val`` is already in the tree. The tree is left
                unchanged, because the exception propagates before any
                child link on the path is reassigned.
        """

        def _add_recursive(node: AvlNode | None, val) -> AvlNode:
            if not node:
                return AvlNode(val)

            if val < node.val:
                node.left = _add_recursive(node.left, val)
            elif val > node.val:
                node.right = _add_recursive(node.right, val)
            else:
                raise KeyError("Node with this value already in tree.")

            return self._rebalance(node)

        self.root = _add_recursive(self.root, val)

    def find_successor(self, node: AvlNode) -> AvlNode | None:
        """
        Find the in-order successor of a node (None if node holds the maximum).

        If node has right subtree: successor is leftmost in right subtree.
        Otherwise: successor is the deepest ancestor whose left subtree
        contains node, found by walking down from the root by value.
        """
        if node.right:
            leftmost = node.right
            while leftmost.left:
                leftmost = leftmost.left
            return leftmost

        # Walk from root, remembering the last node where we turned left.
        successor = None
        current = self.root
        while current and current.val != node.val:
            if node.val < current.val:
                successor = current
                current = current.left
            else:
                current = current.right
        return successor

    def delete(self, val) -> None:
        """Delete a value from the tree. No-op if value not found."""

        def _delete_recursive(node: AvlNode | None, val) -> AvlNode | None:
            if not node:
                return None

            if val < node.val:
                node.left = _delete_recursive(node.left, val)
            elif val > node.val:
                node.right = _delete_recursive(node.right, val)
            else:
                # Found node to delete
                if node.left and node.right:
                    # Two children: copy the successor's value here, then
                    # remove the successor (leftmost of the right subtree,
                    # so it has no left child) from the right subtree.
                    successor = self.find_successor(node)
                    node.val = successor.val
                    node.right = _delete_recursive(node.right, successor.val)
                elif node.left:
                    # Zero/one child: the remaining child subtree is already
                    # balanced, so it can replace node directly.
                    return node.left
                elif node.right:
                    return node.right
                else:
                    return None

            return self._rebalance(node)

        self.root = _delete_recursive(self.root, val)

    def search(self, val) -> AvlNode | None:
        """Search for a value. Returns the node if found, None otherwise."""
        node = self.root
        while node:
            if val == node.val:
                return node
            node = node.left if val < node.val else node.right
        return None

    # -------------------------------------------------------------------------
    # Traversal Methods
    # -------------------------------------------------------------------------

    class TraverseOrder(Enum):
        """Enumeration of tree traversal orders."""

        preorder = 1
        inorder = 2
        postorder = 3

    def traverse(self, order: TraverseOrder) -> list:
        """
        Traverse the tree in the specified order. Returns list of values.

        Raises:
            TypeError: if ``order`` is not an ``AvlTree.TraverseOrder`` member.
        """
        result = []

        def _preorder(node: AvlNode | None):
            if not node:
                return
            result.append(node.val)
            _preorder(node.left)
            _preorder(node.right)

        def _inorder(node: AvlNode | None):
            if not node:
                return
            _inorder(node.left)
            result.append(node.val)
            _inorder(node.right)

        def _postorder(node: AvlNode | None):
            if not node:
                return
            _postorder(node.left)
            _postorder(node.right)
            result.append(node.val)

        traversals = {
            AvlTree.TraverseOrder.preorder: _preorder,
            AvlTree.TraverseOrder.inorder: _inorder,
            AvlTree.TraverseOrder.postorder: _postorder,
        }

        if order not in traversals:
            raise TypeError("Invalid traverse order passed.")

        traversals[order](self.root)
        return result

    def inorder_traversal(self) -> list:
        """Return values in sorted (in-order) sequence."""
        return self.traverse(AvlTree.TraverseOrder.inorder)

    def levels_traversal(self) -> dict[int, list]:
        """Return values grouped by level (breadth-first). Level 0 is root.

        Within each level, values appear left to right.
        """
        if not self.root:
            return {}

        levels: dict[int, list] = {}
        queue = [(self.root, 0)]

        # Iterating a list while appending to it also visits the appended
        # items, so `queue` behaves as a FIFO queue without any popping.
        for node, level in queue:
            levels.setdefault(level, []).append(node.val)

            if node.left:
                queue.append((node.left, level + 1))
            if node.right:
                queue.append((node.right, level + 1))

        return levels

    def min_depth(self) -> int:
        """
        Return minimum depth (shortest path from root to a leaf). -1 if empty.

        Scans level by level (breadth-first) and stops at the first level
        that contains a leaf, instead of visiting every node.
        """
        if not self.root:
            return -1

        depth = 0
        frontier = [self.root]
        while frontier:
            next_frontier = []
            for node in frontier:
                if not node.left and not node.right:
                    return depth
                if node.left:
                    next_frontier.append(node.left)
                if node.right:
                    next_frontier.append(node.right)
            frontier = next_frontier
            depth += 1
        return -1  # Unreachable: a finite non-empty tree always has a leaf.

    def max_depth(self) -> int:
        """
        Return maximum depth (longest path from root to a leaf).

        0 for a single node, -1 if empty. Computed in O(1) from the root's
        cached height (a leaf has height 1), which add/delete keep current
        via _rebalance.
        """
        return self.root.height - 1 if self.root else -1
332
333
# Example usage:
if __name__ == "__main__":
    tree = AvlTree()

    # Build the tree; this insertion sequence triggers rotations.
    for val in (10, 20, 30, 40, 50, 25):
        tree.add(val)

    # Show each traversal order (inorder yields [10, 20, 25, 30, 40, 50]).
    for label, order in (
        ("Inorder: ", AvlTree.TraverseOrder.inorder),
        ("Preorder: ", AvlTree.TraverseOrder.preorder),
        ("Postorder:", AvlTree.TraverseOrder.postorder),
    ):
        print(label, tree.traverse(order))
    print("By levels:", tree.levels_traversal())

    # Lookup: search() returns the node (printed as its value); `in` gives a bool.
    print("Search 25:", tree.search(25))
    print("25 in tree:", 25 in tree)

    # Remove two values, then show the remaining sorted contents.
    tree.delete(40)
    tree.delete(30)
    print("After deletions:", tree.inorder_traversal())
356