class AVLNode:
    """Node of an AVL tree, augmented with subtree height and subtree size."""

    def __init__(self, val):
        self.val = val
        self.left = None
        self.right = None
        # A freshly created node is a leaf: height 1, subtree of one node.
        self.height = 1
        self.size = 1


class AVLTree:
    """Height-balanced binary search tree that tolerates duplicate values.

    Every node caches the height and the node count of its subtree; both
    are refreshed by ``update_height`` on the way back up from an insert
    or delete. Duplicates are sent to the right on descent, but rotations
    may later move an equal value into a left subtree, so the ordering
    invariant is ``left <= node <= right``.
    """

    def __init__(self):
        self.root = None

    def height(self, node):
        """Return the cached height of ``node``; an empty subtree is 0."""
        return node.height if node else 0

    def size(self, node):
        """Return the cached node count of ``node``; an empty subtree is 0."""
        return node.size if node else 0

    def update_height(self, node):
        """Recompute ``node.height`` and ``node.size`` from its children.

        Despite the name, the subtree size is refreshed here as well, so
        that a single call keeps both augmentations consistent.
        """
        node.height = max(self.height(node.left), self.height(node.right)) + 1
        node.size = self.size(node.left) + self.size(node.right) + 1

    def balance_factor(self, node):
        """Return left height minus right height (positive = left-heavy)."""
        return self.height(node.left) - self.height(node.right)

    def rotate_right(self, y):
        """Rotate the subtree rooted at ``y`` to the right.

        ``y.left`` becomes the new subtree root, which is returned.
        """
        x = y.left
        T2 = x.right
        x.right = y
        y.left = T2
        # y is now the child, so it must be refreshed before x.
        self.update_height(y)
        self.update_height(x)
        return x

    def rotate_left(self, x):
        """Rotate the subtree rooted at ``x`` to the left.

        ``x.right`` becomes the new subtree root, which is returned.
        """
        y = x.right
        T2 = y.left
        y.left = x
        x.right = T2
        # x is now the child, so it must be refreshed before y.
        self.update_height(x)
        self.update_height(y)
        return y

    def rebalance(self, node):
        """Restore the AVL property at ``node`` and return the subtree root.

        The cached heights of ``node`` and its children must already be
        current. The rotation is selected from the heavy child's balance
        factor rather than from the key that was inserted or removed, so
        the choice stays correct when equal keys are present.
        """
        balance = self.balance_factor(node)
        if balance > 1:
            # Left-right shape: straighten the child first.
            if self.balance_factor(node.left) < 0:
                node.left = self.rotate_left(node.left)
            return self.rotate_right(node)
        if balance < -1:
            # Right-left shape: straighten the child first.
            if self.balance_factor(node.right) > 0:
                node.right = self.rotate_right(node.right)
            return self.rotate_left(node)
        return node

    def insert(self, val):
        """Insert ``val`` into the tree; duplicates are kept."""
        self.root = self._insert_recursive(self.root, val)

    def _insert_recursive(self, node, val):
        """Insert ``val`` below ``node`` and return the new subtree root."""
        if not node:
            return AVLNode(val)
        if val < node.val:
            node.left = self._insert_recursive(node.left, val)
        else:
            # Equal values descend to the right.
            node.right = self._insert_recursive(node.right, val)
        self.update_height(node)
        return self.rebalance(node)

    def delete(self, val):
        """Remove one node holding ``val``; do nothing if it is absent."""
        self.root = self._delete_recursive(self.root, val)

    def _delete_recursive(self, node, val):
        """Delete ``val`` below ``node`` and return the new subtree root."""
        if not node:
            return None
        if val < node.val:
            node.left = self._delete_recursive(node.left, val)
        elif val > node.val:
            node.right = self._delete_recursive(node.right, val)
        else:
            # Zero or one child: splice the node out directly.
            if not node.left:
                return node.right
            if not node.right:
                return node.left
            # Two children: take over the in-order successor's value, then
            # remove one node carrying that value from the right subtree.
            successor = self._find_min(node.right)
            node.val = successor.val
            node.right = self._delete_recursive(node.right, successor.val)
        self.update_height(node)
        return self.rebalance(node)

    def _find_min(self, node):
        """Return the leftmost node of the subtree rooted at ``node``."""
        current = node
        while current.left:
            current = current.left
        return current
class OrderStatisticAVL:
    """Order-statistic queries (rank / select) over a size-augmented tree.

    This class defines no state and no base class of its own. Its methods
    read ``self.root`` and call ``self.size(node)`` and
    ``self.balance_factor(node)``, none of which are defined here, so it
    must be combined with a class that supplies them (``AVLTree`` in this
    file provides all three).
    """

    def get_rank(self, val):
        """Return number of nodes less than val"""
        return self._get_rank_recursive(self.root, val)

    def _get_rank_recursive(self, node, val):
        """Count the nodes strictly less than ``val`` below ``node``."""
        if not node:
            return 0
        if val <= node.val:
            # Nothing at or right of this node can be smaller than val.
            # Keep descending even on equality: a rotation may have moved
            # an equal key into the left subtree, and it must not be
            # counted as "less than".
            return self._get_rank_recursive(node.left, val)
        # The whole left subtree and this node are smaller than val.
        return (
            self.size(node.left)
            + 1
            + self._get_rank_recursive(node.right, val)
        )

    def select(self, k):
        """Return kth smallest element (0-based)"""
        return self._select_recursive(self.root, k)

    def _select_recursive(self, node, k):
        """Return the value at in-order index ``k`` below ``node``.

        Returns ``None`` when ``k`` falls outside the subtree.
        """
        if not node:
            return None
        rank = self.size(node.left)
        if rank == k:
            return node.val
        if rank > k:
            return self._select_recursive(node.left, k)
        # Skip the left subtree and this node.
        return self._select_recursive(node.right, k - rank - 1)

    def handle_empty(self):
        """Return True when the tree has no root node."""
        return self.root is None

    def is_leaf(self, node):
        """Return True when ``node`` has neither a left nor a right child."""
        return not node.left and not node.right

    def is_critical_imbalance(self, node):
        """Return True when the balance factor magnitude exceeds 1."""
        return abs(self.balance_factor(node)) > 1
class OptimizedAVLNode:
    """AVL node that also carries a subtree node count and a subtree sum."""

    def __init__(self, val):
        self.val = val
        # No children yet.
        self.left = self.right = None
        # A lone node has height 1 and a subtree of exactly one node.
        self.height = self.subtree_size = 1
        # The aggregate starts out as the node's own value.
        self.sum = val
class AVLNodeWithParent:
    """AVL node that keeps a reference to its parent node."""

    def __init__(self, val, parent=None):
        self.val = val
        # No children yet.
        self.left = self.right = None
        # ``parent`` defaults to None when the caller does not supply one.
        self.parent = parent
        # A lone node has height 1.
        self.height = 1
class NodePool:
    """Pre-allocated stock of ``AVLNode`` objects.

    ``pool`` holds ``size`` nodes created up front with value 0, and
    ``available`` holds the indexes into ``pool`` that have not been
    handed out yet. Nothing in this class puts an index back into
    ``available``; once it is empty, ``get_node`` allocates fresh nodes.
    """

    def __init__(self, size=1000):
        self.pool = [AVLNode(0) for _ in range(size)]
        self.available = list(range(size))

    def get_node(self, val):
        """Return a leaf-state node holding ``val``.

        A pooled node is reused while any index is available; otherwise a
        new ``AVLNode`` is created.
        """
        if not self.available:
            return AVLNode(val)
        node = self.pool[self.available.pop()]
        node.val = val
        node.left = node.right = None
        # Reset both cached augmentations, not just the height, so a node
        # whose index was put back into ``available`` cannot enter a tree
        # with a stale subtree size.
        node.height = 1
        node.size = 1
        return node
def clear_tree(self):
    """Detach every node reachable from ``self.root`` and empty the tree.

    Written in method form: ``self`` must expose a ``root`` attribute.
    Each reachable node ends up with ``left``, ``right`` and ``val`` set
    to None, and ``self.root`` is set to None afterwards.
    """
    pending = [self.root]
    while pending:
        node = pending.pop()
        if not node:
            continue
        # Queue the children before the links to them are cut.
        pending.append(node.left)
        pending.append(node.right)
        node.left = node.right = None
        node.val = None
    self.root = None
def batch_insert(self, values):
    """Insert every item of ``values`` in ascending order.

    Written in method form: ``self`` must provide ``insert(val)``.
    ``values`` may be any iterable of mutually comparable items. It is
    sorted into a new list, so the caller's sequence is left untouched.
    """
    for val in sorted(values):
        self.insert(val)
def optimized_rebalance(self, node):
    """Rebalance ``node`` only when its balance factor magnitude exceeds 1.

    Written in method form: ``self`` must provide ``balance_factor(node)``
    and ``rebalance(node)``. Returns ``node`` unchanged when it is already
    within tolerance, otherwise whatever ``self.rebalance(node)`` returns.
    """
    within_tolerance = abs(self.balance_factor(node)) <= 1
    return node if within_tolerance else self.rebalance(node)
def update_height_wrong(self, node):
    """Set ``node.height`` to one more than the taller child's height.

    Both children are dereferenced unconditionally, so this raises
    AttributeError when ``node.left`` or ``node.right`` is None.
    NOTE(review): the ``_wrong`` suffix, next to ``update_height_correct``
    in this file, suggests this is kept as a deliberate counter-example.
    Confirm before calling it from real code.
    """
    left_height = node.left.height
    right_height = node.right.height
    node.height = max(left_height, right_height) + 1
def update_height_correct(self, node):
    """Set ``node.height`` to one more than the taller child's height.

    Written in method form: child heights are obtained through
    ``self.height(child)`` rather than by reading the child directly.
    """
    left_height = self.height(node.left)
    right_height = self.height(node.right)
    node.height = max(left_height, right_height) + 1
def rotate_right_wrong(self, y):
    """Right-rotate around ``y`` by relinking pointers only.

    ``y.left`` becomes the new subtree root and is returned. No cached
    height is refreshed here, unlike ``rotate_right_correct`` in this file.
    NOTE(review): the ``_wrong`` suffix suggests this is kept as a
    deliberate counter-example. Confirm before calling it from real code.
    """
    pivot = y.left
    # The right-hand side is evaluated first, so ``y.left`` receives the
    # pivot's old right child before the pivot adopts ``y``.
    y.left, pivot.right = pivot.right, y
    return pivot
def rotate_right_correct(self, y):
    """Right-rotate around ``y`` and refresh the cached heights.

    Written in method form: ``self`` must provide ``update_height(node)``.
    ``y.left`` becomes the new subtree root and is returned.
    """
    pivot = y.left
    # The right-hand side is evaluated first, so ``y.left`` receives the
    # pivot's old right child before the pivot adopts ``y``.
    pivot.right, y.left = y, pivot.right
    # ``y`` now sits below the pivot, so it is refreshed first.
    for node in (y, pivot):
        self.update_height(node)
    return pivot
def balance_factor_wrong(self, node):
    """Return right-subtree height minus left-subtree height of ``node``.

    Written in method form: heights come from ``self.height(child)``.
    The operands are in the opposite order to ``balance_factor_correct``
    in this file, so the sign is reversed (positive = right-heavy).
    NOTE(review): the ``_wrong`` suffix suggests this is kept as a
    deliberate counter-example. Confirm before calling it from real code.
    """
    right_height = self.height(node.right)
    left_height = self.height(node.left)
    return right_height - left_height
def balance_factor_correct(self, node):
    """Return left-subtree height minus right-subtree height of ``node``.

    Written in method form: heights come from ``self.height(child)``.
    A positive result means the left side is taller.
    """
    left_height = self.height(node.left)
    right_height = self.height(node.right)
    return left_height - right_height