All Data Structures
Splay tree
Splay Trees: Implementation Guide
Core Implementation
Basic Splay Tree
class Node:
    """Binary-tree node holding one key and links to two children."""

    def __init__(self, key):
        # A freshly created node is a leaf: both child links start empty.
        self.left = self.right = None
        self.key = key
class SplayTree:
    """Self-adjusting binary search tree that stores each key at most once.

    ``search``, ``insert`` and ``delete`` all start by splaying the tree
    around the requested key, so the node holding that key (or, when it is
    absent, the last node reached while looking for it) becomes the root.
    """

    def __init__(self):
        # Root node of the tree; None while the tree is empty.
        self.root = None

    def _right_rotate(self, x):
        """Rotate ``x`` down to the right and return its former left child.

        ``x.left`` must not be None.
        """
        y = x.left
        x.left = y.right
        y.right = x
        return y

    def _left_rotate(self, x):
        """Rotate ``x`` down to the left and return its former right child.

        ``x.right`` must not be None.
        """
        y = x.right
        x.right = y.left
        y.left = x
        return y

    def _splay(self, root, key):
        """Splay the subtree rooted at ``root`` around ``key``.

        Returns the new subtree root: the node holding ``key`` when it is
        present, otherwise the last node reached on the search path.
        Returns None when ``root`` is None.

        The search path is walked with an explicit stack rather than by
        recursion, so a long degenerate chain (for example after many
        sorted inserts) cannot exhaust the interpreter's recursion limit.
        The rotations applied, and their order, are those of the
        two-levels-at-a-time recursive formulation.
        """
        # Phase 1: descend two levels per step, remembering each pending
        # double step as (node, first_step_left, second_step_left).
        pending = []
        node = root
        while node is not None and node.key != key:
            if node.key > key:  # key is in left subtree
                child = node.left
                if child is None:
                    break
                if child.key > key:  # Zig-Zig (left left)
                    pending.append((node, True, True))
                    node = child.left
                elif child.key < key:  # Zig-Zag (left right)
                    pending.append((node, True, False))
                    node = child.right
                else:  # Zig: the child itself holds the key
                    node = self._right_rotate(node)
                    break
            else:  # key is in right subtree
                child = node.right
                if child is None:
                    break
                if child.key < key:  # Zig-Zig (right right)
                    pending.append((node, False, False))
                    node = child.right
                elif child.key > key:  # Zig-Zag (right left)
                    pending.append((node, False, True))
                    node = child.left
                else:  # Zig: the child itself holds the key
                    node = self._left_rotate(node)
                    break

        # Phase 2: unwind. ``node`` is the splayed grandchild subtree (it
        # may be None); hook it back in and rotate it above ``top``.
        while pending:
            top, first_left, second_left = pending.pop()
            if first_left:
                if second_left:
                    top.left.left = node
                    top = self._right_rotate(top)
                else:
                    top.left.right = node
                    if node is not None:
                        top.left = self._left_rotate(top.left)
                node = self._right_rotate(top) if top.left is not None else top
            else:
                if second_left:
                    top.right.left = node
                    if node is not None:
                        top.right = self._right_rotate(top.right)
                else:
                    top.right.right = node
                    top = self._left_rotate(top)
                node = self._left_rotate(top) if top.right is not None else top
        return node

    def search(self, key):
        """Splay around ``key`` and return True if the tree contains it.

        Always returns a bool, including False for an empty tree.
        """
        self.root = self._splay(self.root, key)
        return self.root is not None and self.root.key == key

    def insert(self, key):
        """Insert ``key``; a key already present is left untouched."""
        if not self.root:
            self.root = Node(key)
            return
        self.root = self._splay(self.root, key)
        if self.root.key == key:
            return
        # The splayed root is a neighbour of ``key``: split the tree at the
        # root and hang the two halves under the new node.
        new_node = Node(key)
        if self.root.key > key:
            new_node.right = self.root
            new_node.left = self.root.left
            self.root.left = None
        else:
            new_node.left = self.root
            new_node.right = self.root.right
            self.root.right = None
        self.root = new_node

    def delete(self, key):
        """Remove ``key`` if present; otherwise only the splay takes place."""
        if not self.root:
            return
        self.root = self._splay(self.root, key)
        if self.root.key != key:
            return
        if not self.root.left:
            self.root = self.root.right
        else:
            temp = self.root
            # Every key in the left subtree is smaller than ``key``, so
            # splaying it around ``key`` brings its maximum to the top with
            # an empty right link, where the old right subtree is attached.
            self.root = self._splay(self.root.left, key)
            self.root.right = temp.right
Implementation Variations
1. Bottom-Up Splay Tree
class BottomUpSplayTree(SplayTree):
    """Splay-tree variant offering a loop-based splay routine."""

    def _splay_bottom_up(self, root, key):
        """Splay the subtree at ``root`` around ``key`` without recursion.

        While descending, nodes that are passed are detached and chained
        onto a dummy header: nodes on the smaller-key side are linked
        through ``right`` pointers, nodes on the larger-key side through
        ``left`` pointers. When the descent stops, both chains are
        re-attached under the final node, which is returned. Returns None
        for an empty subtree.

        NOTE(review): the header-and-two-chains structure resembles the
        top-down splay formulation rather than a bottom-up one; confirm the
        naming with the author.
        """
        if not root:
            return None

        header = Node(None)  # Dummy node
        header.left = None
        header.right = None
        smaller_tail = header  # last node linked on the smaller-key chain
        larger_tail = header   # last node linked on the larger-key chain
        node = root

        while True:
            if key < node.key:
                if not node.left:
                    break
                if key < node.left.key:
                    # Two left steps in a row: rotate first.
                    node = self._right_rotate(node)
                    if not node.left:
                        break
                # Hang the current node on the larger-key chain, step left.
                larger_tail.left = node
                larger_tail = node
                node = node.left
            elif key > node.key:
                if not node.right:
                    break
                if key > node.right.key:
                    # Two right steps in a row: rotate first.
                    node = self._left_rotate(node)
                    if not node.right:
                        break
                # Hang the current node on the smaller-key chain, step right.
                smaller_tail.right = node
                smaller_tail = node
                node = node.right
            else:
                break

        # Reassemble: the final node's own children close the two chains,
        # then the chains become its new children.
        smaller_tail.right = node.left
        larger_tail.left = node.right
        node.left = header.right
        node.right = header.left
        return node
2. Augmented Splay Tree with Size
class SizeNode(Node):
    """Node that also records how many nodes its own subtree contains."""

    def __init__(self, key):
        super().__init__(key)
        # A new node has no children, so its subtree is just itself.
        self.size = 1
class SizeSplayTree(SplayTree):
    """Splay tree whose nodes carry subtree sizes for order statistics.

    Every node in the tree must be a ``SizeNode``. ``insert`` and
    ``delete`` are overridden so that nodes are created with a ``size``
    field and the sizes stay correct after the tree is re-linked; the
    inherited versions create plain nodes and never refresh the counts.
    """

    def _update_size(self, node):
        """Recompute ``node.size`` from its children; no-op for None."""
        if node is None:
            return
        total = 1
        if node.left is not None:
            total += node.left.size
        if node.right is not None:
            total += node.right.size
        node.size = total

    def _right_rotate(self, x):
        """Right-rotate ``x`` and refresh the two sizes that changed."""
        y = super()._right_rotate(x)
        # ``x`` is now the child of ``y``, so it must be recomputed first.
        self._update_size(x)
        self._update_size(y)
        return y

    def _left_rotate(self, x):
        """Left-rotate ``x`` and refresh the two sizes that changed."""
        y = super()._left_rotate(x)
        # ``x`` is now the child of ``y``, so it must be recomputed first.
        self._update_size(x)
        self._update_size(y)
        return y

    def insert(self, key):
        """Insert ``key`` as a ``SizeNode``; existing keys are left untouched."""
        if self.root is None:
            self.root = SizeNode(key)
            return
        self.root = self._splay(self.root, key)
        if self.root.key == key:
            return
        old_root = self.root
        new_node = SizeNode(key)
        if old_root.key > key:
            new_node.right = old_root
            new_node.left = old_root.left
            old_root.left = None
        else:
            new_node.left = old_root
            new_node.right = old_root.right
            old_root.right = None
        # The old root lost one subtree and is now a child of the new node,
        # so refresh it before the new node.
        self._update_size(old_root)
        self._update_size(new_node)
        self.root = new_node

    def delete(self, key):
        """Remove ``key`` if present and keep the root's size correct."""
        super().delete(key)
        # The base class re-attaches the right subtree to the new root
        # without touching sizes; recomputing the root repairs that.
        self._update_size(self.root)

    def select(self, k):
        """Find kth smallest element

        ``k`` is 1-based. Returns the key, or None when the tree is empty
        or ``k`` is outside ``1..size``.
        """
        node = self.root
        if node is None or k < 1 or k > node.size:
            return None
        while node is not None:
            left_size = node.left.size if node.left is not None else 0
            if k == left_size + 1:
                return node.key
            if k <= left_size:
                node = node.left
            else:
                # Skip the left subtree and this node.
                k -= left_size + 1
                node = node.right
        return None
Time Complexities
Operation | Amortized | Worst Case |
---|---|---|
Search | O(log n) | O(n) |
Insert | O(log n) | O(n) |
Delete | O(log n) | O(n) |
Select | O(log n) | O(n) |
Join | O(log n) | O(n) |
Split | O(log n) | O(n) |
Space Complexities
Implementation | Space Complexity | Additional Notes |
---|---|---|
Basic Splay Tree | O(n) | One node per element |
Bottom-Up Splay | O(n) | One temporary header node per splay call |
Augmented with Size | O(n) | Extra size field per node |
Common Operations Patterns
- Access Patterns
  - Recently accessed elements
  - Temporal locality
  - Working set maintenance
- Tree Rotations
  - Zig step
  - Zig-zig step
  - Zig-zag step
- Balance Operations
  - Splaying
  - Path compression
  - Tree restructuring
- Split/Join Operations
  - Tree decomposition
  - Tree merging
  - Range operations
- Query Operations
  - Range queries
  - Order statistics
  - Predecessor/Successor
- Statistical Operations
  - Selection queries
  - Rank queries
  - Range counting
Edge Cases to Consider
- Empty Tree
def handle_empty(self):
    """Raise ``ValueError`` when the tree has no root; otherwise do nothing."""
    if self.root:
        return
    raise ValueError("Operation on empty tree")
- Single Node
def is_leaf(self, node):
    """Return True when ``node`` has neither a left nor a right child."""
    return not node.left and not node.right
- Duplicate Keys
def insert_with_duplicates(self, key):
    """Skeleton for an insert that treats repeated keys specially.

    An empty tree gets a new root holding ``key``. Otherwise the tree is
    splayed around ``key``; the branch for an equal root key is a
    placeholder, and nothing further happens for a non-equal key.

    NOTE(review): as written no node is ever added to a non-empty tree;
    this looks like a deliberately truncated example. Confirm before reuse.
    """
    if self.root:
        self.root = self._splay(self.root, key)
        if self.root.key == key:
            # Handle duplicate: create new node or update count
            pass
        return
    self.root = Node(key)
Optimization Techniques
- Path Compression
def _compress_path(self, node, key):
"""Compress path during splay operation"""
if not node:
return None
ancestors = []
current = node
# Collect ancestors
while current and current.key != key:
ancestors.append(current)
current = current.left if key < current.key else current.right
# Compress path
for i in range(len(ancestors) - 1, 0, -1):
parent = ancestors[i-1]
child = ancestors[i]
if key < parent.key:
parent.left = child.right
child.right = parent
else:
parent.right = child.left
child.left = parent
- Bulk Operations
def bulk_insert(self, keys):
    """Optimized bulk insertion

    Sorts ``keys`` and drops repeated values, so the tree never receives
    two nodes with the same key. An empty tree is replaced by a perfectly
    balanced tree built from the sorted keys. A non-empty tree receives
    the keys one at a time through ``insert``: joining a separately built
    subtree is only valid when every key of one tree is smaller than every
    key of the other, which cannot be assumed for arbitrary input.

    An empty ``keys`` leaves the tree unchanged.
    """
    ordered = sorted(keys)
    # Adjacent comparison after sorting removes duplicates without
    # requiring the keys to be hashable.
    unique = [
        k for i, k in enumerate(ordered) if i == 0 or k != ordered[i - 1]
    ]
    if not unique:
        return

    if self.root:
        for k in unique:
            self.insert(k)
        return

    def build_balanced(start, end):
        # Middle element becomes the subtree root; halves recurse.
        if start > end:
            return None
        mid = (start + end) // 2
        node = Node(unique[mid])
        node.left = build_balanced(start, mid - 1)
        node.right = build_balanced(mid + 1, end)
        return node

    self.root = build_balanced(0, len(unique) - 1)
Memory Management
- Node Pool
class NodePool:
    """Fixed-size pool of pre-allocated ``Node`` objects for reuse.

    ``pool`` holds the pooled nodes and ``available`` the indices of the
    ones currently free. Once the pool is exhausted ``get_node`` falls
    back to allocating ordinary, unpooled nodes.
    """

    def __init__(self, size=1000):
        self.pool = [Node(None) for _ in range(size)]
        self.available = list(range(size))
        # Identity -> slot index, so returning a node needs no linear scan
        # and nodes that never belonged to the pool can be recognised.
        self._slot_of = {id(node): idx for idx, node in enumerate(self.pool)}
        # Slots currently handed out; guards against double returns.
        self._in_use = set()

    def get_node(self, key):
        """Return a cleared node holding ``key``, pooled when one is free."""
        if self.available:
            idx = self.available.pop()
            self._in_use.add(idx)
            node = self.pool[idx]
            node.key = key
            node.left = node.right = None
            return node
        # Pool exhausted: hand out a regular node that is not tracked.
        return Node(key)

    def return_node(self, node):
        """Give ``node`` back to the pool.

        Nodes that did not come from this pool (for example overflow nodes
        created while the pool was exhausted) and nodes that were already
        returned are ignored, so a slot can never be listed as free twice.
        """
        idx = self._slot_of.get(id(node))
        if idx is None or idx not in self._in_use:
            return
        self._in_use.remove(idx)
        # Drop references so a parked node does not keep a subtree alive.
        node.key = None
        node.left = node.right = None
        self.available.append(idx)
- Memory-Efficient Node
class CompactNode:
    """Tree node restricted to three fixed attributes via ``__slots__``.

    Declaring ``__slots__`` means instances have no per-object
    ``__dict__`` and cannot grow extra attributes.
    """

    __slots__ = ['key', 'left', 'right']

    def __init__(self, key):
        # Starts out as a leaf.
        self.left = self.right = None
        self.key = key
Performance Considerations
- Access Pattern Optimization
class CacheAwareSplayTree(SplayTree):
    """Splay tree with a bounded dictionary of keys known to be present.

    A cache hit answers ``search`` without touching (or splaying) the
    tree. The cache only ever holds keys that a tree search has found, and
    ``delete`` drops the deleted key from it so a removed key is never
    reported as present.
    """

    def __init__(self, cache_size=1000):
        super().__init__()
        # key -> True for keys confirmed present by a tree search.
        self.cache = {}
        # Upper bound on the number of cached keys; no eviction happens,
        # new keys are simply not cached once the bound is reached.
        self.cache_size = cache_size

    def search(self, key):
        """Return whether ``key`` is present, consulting the cache first."""
        # Check cache first
        if key in self.cache:
            return True
        found = super().search(key)
        if found and len(self.cache) < self.cache_size:
            self.cache[key] = True
        return found

    def delete(self, key):
        """Remove ``key`` from the tree and invalidate its cache entry."""
        # Invalidate first so the cache never outlives the tree entry.
        self.cache.pop(key, None)
        super().delete(key)
- Rotation Cost Management
class AdaptiveSplayTree(SplayTree):
    """Splay tree that skips splaying when rotations become too frequent.

    ``rotation_count`` counts every rotation performed and
    ``access_count`` every ``search`` call. After the first 100 searches,
    a search splays only while the rotations-per-search ratio stays below
    0.5; otherwise it does a plain read-only lookup.
    """

    def __init__(self):
        super().__init__()
        self.rotation_count = 0
        self.access_count = 0

    def _right_rotate(self, x):
        """Right-rotate ``x`` and record the rotation."""
        self.rotation_count += 1
        return super()._right_rotate(x)

    def _left_rotate(self, x):
        """Left-rotate ``x`` and record the rotation."""
        self.rotation_count += 1
        return super()._left_rotate(x)

    def _should_splay(self):
        """Return True while splaying is still considered worthwhile."""
        if self.access_count < 100:
            # Too few samples for the ratio to mean anything yet.
            return True
        return (self.rotation_count / self.access_count) < 0.5

    def _bst_search(self, node, key):
        """Return True if ``key`` is in the subtree at ``node``; no restructuring."""
        while node is not None:
            if node.key == key:
                return True
            node = node.left if node.key > key else node.right
        return False

    def search(self, key):
        """Look up ``key``, splaying only when ``_should_splay`` allows it."""
        self.access_count += 1
        if not self._should_splay():
            # Perform simple BST search instead
            return self._bst_search(self.root, key)
        return super().search(key)
Common Pitfalls
- Improper Splaying
# Wrong
def search_wrong(self, key):
    """Anti-pattern: look ``key`` up without splaying the tree.

    Returns True when ``self._find(key)`` yields a node, else False; the
    tree is never restructured by this call.
    """
    return self._find(key) is not None  # Simple BST search
# Correct
def search_correct(self, key):
    """Splay the tree around ``key`` and report whether it is present.

    The splayed subtree root is stored back into ``self.root``. Always
    returns a bool, including False for an empty tree.
    """
    self.root = self._splay(self.root, key)
    return self.root is not None and self.root.key == key
- Missing Root Update
# Wrong
def insert_wrong(self, key):
    """Anti-pattern: splay but throw away the new root.

    An empty tree gets a root holding ``key``. Otherwise ``_splay`` is
    called and its return value is discarded, so ``self.root`` keeps
    pointing at the old root node.
    """
    if self.root:
        self._splay(self.root, key)  # Didn't update root
        return
    self.root = Node(key)
# Correct
def insert_correct(self, key):
    """Show the root update that must follow a splay.

    An empty tree gets a root holding ``key``. Otherwise the tree is
    splayed around ``key`` and the returned node is stored as the new
    root. The snippet stops there; no node is linked in afterwards.
    """
    if self.root:
        self.root = self._splay(self.root, key)
        return
    self.root = Node(key)
- Incorrect Rotation
# Wrong
def _rotate_wrong(self, x):
    # Anti-pattern kept for illustration: a right rotation squeezed into
    # one tuple assignment. The right-hand side is evaluated first, then
    # the targets are assigned left to right, so ``x.left`` is overwritten
    # with the old grandchild before ``x.left.right`` is resolved. The
    # former left child is no longer reachable from ``x``, ``x`` becomes
    # the right child of its own left child (or AttributeError is raised
    # when that grandchild is None), and nothing is returned.
    if x.left:
        x.left, x.left.right = x.left.right, x # Lost references
# Correct
def _rotate_correct(self, x):
y = x.left
x.left = y.right
y.right = x
return y
On this page
- Splay Trees: Implementation Guide
- Core Implementation
- Basic Splay Tree
- Implementation Variations
- 1. Bottom-Up Splay Tree
- 2. Augmented Splay Tree with Size
- Time Complexities
- Space Complexities
- Common Operations Patterns
- Edge Cases to Consider
- Optimization Techniques
- Memory Management
- Performance Considerations
- Common Pitfalls