class AVLNode:
    """A single node of an AVL tree.

    Stores one value, links to its left and right children, and two cached
    subtree metrics that the owning tree keeps up to date.
    """

    def __init__(self, val):
        # Payload and child links; a freshly created node is a leaf.
        self.val = val
        self.left = self.right = None
        # A lone node has height 1 and subtree size 1; size is kept
        # for order-statistic queries.
        self.height = self.size = 1
class AVLTree:
    """Self-balancing (AVL) binary search tree whose nodes also cache subtree sizes.

    Duplicate values are allowed. A value equal to a node's value is placed
    in that node's right subtree.
    """

    def __init__(self):
        self.root = None

    def height(self, node):
        """Return the cached height of *node*, treating ``None`` as height 0."""
        return node.height if node else 0

    def size(self, node):
        """Return the cached subtree size of *node*, treating ``None`` as 0."""
        return node.size if node else 0

    def update_height(self, node):
        """Recompute *node*'s cached height and size from its children.

        The children's cached values must already be up to date.
        """
        node.height = max(self.height(node.left), self.height(node.right)) + 1
        node.size = self.size(node.left) + self.size(node.right) + 1

    def balance_factor(self, node):
        """Return ``height(left) - height(right)`` for *node*."""
        return self.height(node.left) - self.height(node.right)

    def rotate_right(self, y):
        """Rotate the subtree rooted at *y* to the right and return its new root.

        *y* must have a left child.
        """
        x = y.left
        T2 = x.right
        x.right = y
        y.left = T2
        # y is now a child of x, so refresh y before x.
        self.update_height(y)
        self.update_height(x)
        return x

    def rotate_left(self, x):
        """Rotate the subtree rooted at *x* to the left and return its new root.

        *x* must have a right child.
        """
        y = x.right
        T2 = y.left
        y.left = x
        x.right = T2
        # x is now a child of y, so refresh x before y.
        self.update_height(x)
        self.update_height(y)
        return y

    def insert(self, val):
        """Insert *val* into the tree, rebalancing along the insertion path."""
        self.root = self._insert_recursive(self.root, val)

    def _insert_recursive(self, node, val):
        """Insert *val* below *node* and return the (possibly new) subtree root."""
        # Standard BST insert; equal values go right.
        if not node:
            return AVLNode(val)
        if val < node.val:
            node.left = self._insert_recursive(node.left, val)
        else:
            node.right = self._insert_recursive(node.right, val)
        self.update_height(node)
        return self._rebalance(node)

    def _rebalance(self, node):
        """Restore the AVL invariant at *node* and return the new subtree root.

        The rotation case is chosen from the heavy child's balance factor, not
        by comparing the inserted value with the child's value. With value
        comparison, a duplicate equal to the child's value is stored in the
        child's right subtree but satisfies neither ``<`` nor ``>``, so no
        rotation fires and the node is left unbalanced.
        """
        balance = self.balance_factor(node)
        if balance > 1:
            # Left-heavy. If the left child leans right (Left-Right case),
            # first straighten it into the Left-Left shape.
            if self.balance_factor(node.left) < 0:
                node.left = self.rotate_left(node.left)
            return self.rotate_right(node)
        if balance < -1:
            # Right-heavy. If the right child leans left (Right-Left case),
            # first straighten it into the Right-Right shape.
            if self.balance_factor(node.right) > 0:
                node.right = self.rotate_right(node.right)
            return self.rotate_left(node)
        return node