Splay Trees: Implementation Guide

Core Implementation

Basic Splay Tree

class Node:
    """Binary-tree node: one key plus links to a left and a right child."""

    def __init__(self, key):
        self.key = key
        # A new node starts detached, with no children on either side.
        self.left = self.right = None

class SplayTree:
    """Self-adjusting binary search tree over unique keys.

    Every public operation first splays the tree, i.e. rotates the requested
    key (or the last node reached while looking for it) up to the root.
    """

    def __init__(self):
        self.root = None

    def _right_rotate(self, x):
        """Rotate the subtree rooted at ``x`` to the right.

        ``x.left`` must exist; it becomes the new subtree root and is
        returned. Re-attaching the result is the caller's job.
        """
        y = x.left
        x.left = y.right
        y.right = x
        return y

    def _left_rotate(self, x):
        """Rotate the subtree rooted at ``x`` to the left.

        ``x.right`` must exist; it becomes the new subtree root and is
        returned. Re-attaching the result is the caller's job.
        """
        y = x.right
        x.right = y.left
        y.left = x
        return y

    def _splay(self, root, key):
        """Splay ``key`` to the top of the subtree rooted at ``root``.

        Returns the new subtree root: the node holding ``key`` when it is
        present, otherwise the last node reached while searching for it.
        Returns ``None`` for an empty subtree.

        The routine is recursive and descends two levels per call, so the
        recursion depth grows with the length of the search path.
        """
        if root is None or root.key == key:
            return root

        if root.key > key:  # key is in left subtree
            if not root.left:
                return root

            if root.left.key > key:  # Zig-Zig (left left)
                # Splay inside the grandchild first, then rotate twice; the
                # second rotation is the shared ``return`` below.
                root.left.left = self._splay(root.left.left, key)
                root = self._right_rotate(root)
            elif root.left.key < key:  # Zig-Zag (left right)
                root.left.right = self._splay(root.left.right, key)
                if root.left.right:
                    root.left = self._left_rotate(root.left)

            # Final (or only) rotation. It is skipped when nothing is left
            # on this side after the first zig-zig rotation.
            return self._right_rotate(root) if root.left else root

        else:  # key is in right subtree
            if not root.right:
                return root

            if root.right.key < key:  # Zig-Zig (right right)
                root.right.right = self._splay(root.right.right, key)
                root = self._left_rotate(root)
            elif root.right.key > key:  # Zig-Zag (right left)
                root.right.left = self._splay(root.right.left, key)
                if root.right.left:
                    root.right = self._right_rotate(root.right)

            return self._left_rotate(root) if root.right else root

    def search(self, key):
        """Splay ``key`` to the root and report whether it is present.

        Always returns a real ``bool``. An empty tree yields ``False``
        rather than leaking the ``None`` root through an ``and`` expression.
        """
        self.root = self._splay(self.root, key)
        return self.root is not None and self.root.key == key

    def insert(self, key):
        """Insert ``key``; a key that is already present is left untouched."""
        if not self.root:
            self.root = Node(key)
            return

        self.root = self._splay(self.root, key)

        if self.root.key == key:
            return

        # After the splay the root is a neighbour of ``key`` in sort order,
        # so the tree can be split at the root and re-joined under the new
        # node.
        new_node = Node(key)
        if self.root.key > key:
            new_node.right = self.root
            new_node.left = self.root.left
            self.root.left = None
        else:
            new_node.left = self.root
            new_node.right = self.root.right
            self.root.right = None

        self.root = new_node

    def delete(self, key):
        """Remove ``key`` if present; otherwise only the splay takes effect."""
        if not self.root:
            return

        self.root = self._splay(self.root, key)
        if self.root.key != key:
            return

        if not self.root.left:
            self.root = self.root.right
        else:
            temp = self.root
            # Every key in the left subtree is smaller than ``key``, so this
            # splay walks right only and brings up a node with no right
            # child. The old right subtree can then be hung there.
            self.root = self._splay(self.root.left, key)
            self.root.right = temp.right

Implementation Variations

1. Bottom-Up Splay Tree

class BottomUpSplayTree(SplayTree):
    """Splay-tree variant that adds a loop-based (non-recursive) splay routine."""

    def _splay_bottom_up(self, root, key):
        """Splay ``key`` within the tree rooted at ``root`` without recursion.

        Walks down from ``root``, hanging the nodes it passes onto two side
        chains anchored at a throwaway header node, then reassembles both
        chains around the node where the walk stopped.

        Returns the new root, or ``None`` when ``root`` is empty.
        """
        if not root:
            return None

        # Throwaway anchor; Node() already leaves both of its links empty.
        anchor = Node(None)
        # Tails of the "keys smaller than the target" and "keys larger than
        # the target" chains. Both start at the anchor.
        smaller_tail = larger_tail = anchor
        node = root

        while True:
            if key < node.key:
                if not node.left:
                    break
                if key < node.left.key:
                    # Two steps the same way: rotate before linking.
                    node = self._right_rotate(node)
                    if not node.left:
                        break
                # Link the current node into the "larger" chain and descend.
                larger_tail.left = node
                larger_tail, node = node, node.left
            elif key > node.key:
                if not node.right:
                    break
                if key > node.right.key:
                    node = self._left_rotate(node)
                    if not node.right:
                        break
                # Link the current node into the "smaller" chain and descend.
                smaller_tail.right = node
                smaller_tail, node = node, node.right
            else:
                break

        # Reassemble. The order matters because either tail may still be the
        # anchor itself.
        smaller_tail.right = node.left
        larger_tail.left = node.right
        node.left = anchor.right
        node.right = anchor.left

        return node

2. Augmented Splay Tree with Size

class SizeNode(Node):
    """Tree node that additionally carries a ``size`` counter."""

    def __init__(self, key):
        super().__init__(key)
        # A freshly created node counts only itself.
        self.size: int = 1

class SizeSplayTree(SplayTree):
    """Splay tree whose nodes track subtree sizes for order statistics.

    Nodes are ``SizeNode`` instances. ``insert`` and ``delete`` are
    overridden because the inherited versions create plain ``Node`` objects
    (which have no ``size``) and never refresh the size stored at the root.
    """

    def _update_size(self, node):
        """Recompute ``node.size`` from its children; no-op for ``None``."""
        if node:
            node.size = 1
            if node.left:
                node.size += node.left.size
            if node.right:
                node.size += node.right.size

    def _right_rotate(self, x):
        """Right-rotate at ``x`` and refresh the two sizes that changed."""
        y = super()._right_rotate(x)
        # ``x`` is now a child of ``y``, so it must be recomputed first.
        self._update_size(x)
        self._update_size(y)
        return y

    def _left_rotate(self, x):
        """Left-rotate at ``x`` and refresh the two sizes that changed."""
        y = super()._left_rotate(x)
        self._update_size(x)
        self._update_size(y)
        return y

    def insert(self, key):
        """Insert ``key`` as a ``SizeNode``, keeping subtree sizes correct.

        A key that is already present is left untouched.
        """
        if not self.root:
            self.root = SizeNode(key)
            return

        self.root = self._splay(self.root, key)
        if self.root.key == key:
            return

        old_root = self.root
        new_node = SizeNode(key)
        if old_root.key > key:
            new_node.right = old_root
            new_node.left = old_root.left
            old_root.left = None
        else:
            new_node.left = old_root
            new_node.right = old_root.right
            old_root.right = None

        # The old root lost one subtree and is now a child of the new node,
        # so it is recomputed first, then its new parent.
        self._update_size(old_root)
        self._update_size(new_node)
        self.root = new_node

    def delete(self, key):
        """Remove ``key`` if present, keeping subtree sizes correct."""
        if not self.root:
            return

        self.root = self._splay(self.root, key)
        if self.root.key != key:
            return

        removed = self.root
        if not removed.left:
            self.root = removed.right
        else:
            # Splaying the left subtree for a key larger than all of its
            # keys brings up a node without a right child.
            self.root = self._splay(removed.left, key)
            self.root.right = removed.right
            self._update_size(self.root)

    def select(self, k):
        """Find kth smallest element (1-based).

        Returns ``None`` when the tree is empty or ``k`` lies outside
        ``1..size``, instead of failing on a missing child.
        """
        node = self.root
        if node is None or not 1 <= k <= node.size:
            return None

        while node is not None:
            left_size = node.left.size if node.left else 0

            if k == left_size + 1:
                return node.key
            if k <= left_size:
                node = node.left
            else:
                # Skip the left subtree and the current node.
                k -= left_size + 1
                node = node.right

        return None

Time Complexities

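| Operation | Amortized | Worst Case |
|-----------|-----------|------------|
| Search    | O(log n)  | O(n)       |
| Insert    | O(log n)  | O(n)       |
| Delete    | O(log n)  | O(n)       |
| Select    | O(log n)  | O(n)       |
| Join      | O(log n)  | O(n)       |
| Split     | O(log n)  | O(n)       |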

Space Complexities

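| Implementation      | Space Complexity | Additional Notes           |
|---------------------|------------------|----------------------------|
| Basic Splay Tree    | O(n)             | One node per element       |
| Bottom-Up Splay     | O(n)             | Additional temporary space |
| Augmented with Size | O(n)             | Extra size field per node  |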

Common Operations Patterns

  1. Access Patterns

    • Recently accessed elements
    • Temporal locality
    • Working set maintenance
  2. Tree Rotations

    • Zig step
    • Zig-zig step
    • Zig-zag step
  3. Balance Operations

    • Splaying
    • Path compression
    • Tree restructuring
  4. Split/Join Operations

    • Tree decomposition
    • Tree merging
    • Range operations
  5. Query Operations

    • Range queries
    • Order statistics
    • Predecessor/Successor
  6. Statistical Operations

    • Selection queries
    • Rank queries
    • Range counting

Edge Cases to Consider

  1. Empty Tree
def handle_empty(self):
    """Raise ``ValueError`` when the tree has no root; otherwise do nothing."""
    if self.root:
        return
    raise ValueError("Operation on empty tree")
  2. Single Node
def is_leaf(self, node):
    """Return True when ``node`` has neither a left nor a right child."""
    return not node.left and not node.right
  3. Duplicate Keys
def insert_with_duplicates(self, key):
    """Insert ``key``, counting repeats instead of silently dropping them.

    A key that is already in the tree gets its node's ``count`` attribute
    incremented. A node without the attribute is treated as holding one
    occurrence. A key that is not yet in the tree is inserted with the same
    root split used by a regular insert.
    """
    if not self.root:
        self.root = Node(key)
        return

    self.root = self._splay(self.root, key)
    if self.root.key == key:
        # Handle duplicate: update count
        self.root.count = getattr(self.root, "count", 1) + 1
        return

    # Not a duplicate: the splayed root is a neighbour of ``key`` in sort
    # order, so split the tree at the root around the new node.
    new_node = Node(key)
    if self.root.key > key:
        new_node.right = self.root
        new_node.left = self.root.left
        self.root.left = None
    else:
        new_node.left = self.root
        new_node.right = self.root.right
        self.root.right = None

    self.root = new_node

Optimization Techniques

  1. Path Compression
def _compress_path(self, node, key):
    """Compress path during splay operation.

    Follows the search path for ``key`` downward from ``node``. The node
    where the search ends (the node holding ``key``, or the last node
    visited when ``key`` is absent) is then rotated above every ancestor on
    that path, one single rotation per ancestor.

    Returns the new root of the restructured subtree, or ``None`` for an
    empty subtree. The caller must store the returned node wherever
    ``node`` was referenced.
    """
    if not node:
        return None

    ancestors = []
    current = node

    # Collect ancestors
    while current and current.key != key:
        ancestors.append(current)
        current = current.left if key < current.key else current.right

    # Node to promote: the match itself, or the last node on the path.
    top = current if current else ancestors.pop()

    # Compress path. ``top`` always sits on the side of ``parent`` that the
    # search went to, so each step is a plain rotation of ``top`` over
    # ``parent``. Carrying ``top`` upward keeps the links consistent;
    # pairing fixed neighbours from the list would reuse stale links.
    for parent in reversed(ancestors):
        if key < parent.key:
            parent.left = top.right
            top.right = parent
        else:
            parent.right = top.left
            top.left = parent

    return top
  2. Bulk Operations
def bulk_insert(self, keys):
    """Optimized bulk insertion.

    Keys are sorted and de-duplicated first, matching the regular insert,
    which ignores keys that are already present. An empty tree is replaced
    by a balanced tree built directly from the sorted keys. A non-empty
    tree receives the keys one by one through ``self.insert``. Joining a
    prebuilt subtree is only valid when every key of one tree is smaller
    than every key of the other, which arbitrary input cannot guarantee.
    """
    ordered = []
    for key in sorted(keys):
        # Equal keys are adjacent after sorting; keep only the first.
        if not ordered or ordered[-1] != key:
            ordered.append(key)

    def build_balanced(start, end):
        """Build a balanced subtree from ``ordered[start:end + 1]``."""
        if start > end:
            return None

        mid = (start + end) // 2
        node = Node(ordered[mid])
        node.left = build_balanced(start, mid - 1)
        node.right = build_balanced(mid + 1, end)
        return node

    if not self.root:
        self.root = build_balanced(0, len(ordered) - 1)
        return

    for key in ordered:
        self.insert(key)

Memory Management

  1. Node Pool
class NodePool:
    """Fixed-size pool of reusable nodes with an allocation fallback.

    ``pool`` holds the preallocated nodes and ``available`` holds the
    indices of the nodes that are currently free. When the pool is
    exhausted, ``get_node`` allocates a fresh node that is not owned by the
    pool.
    """

    def __init__(self, size=1000):
        self.pool = [Node(None) for _ in range(size)]
        self.available = list(range(size))
        # id(node) -> pool index for nodes currently handed out. Pool nodes
        # stay alive in ``self.pool``, so their ids cannot be reused.
        self._checked_out = {}

    def get_node(self, key):
        """Return a clean node holding ``key``, reusing a pooled one if free."""
        if self.available:
            idx = self.available.pop()
            node = self.pool[idx]
            node.key = key
            node.left = node.right = None
            self._checked_out[id(node)] = idx
            return node
        return Node(key)

    def return_node(self, node):
        """Give ``node`` back to the pool.

        Nodes that did not come from the pool (overflow allocations) and
        nodes that were already returned are ignored. This avoids a
        ``ValueError`` for overflow nodes and stops the same slot from
        being handed out twice.
        """
        idx = self._checked_out.pop(id(node), None)
        if idx is None:
            return
        node.key = None
        node.left = node.right = None
        self.available.append(idx)
  2. Memory-Efficient Node
class CompactNode:
    """Tree node restricted to three slotted fields: key, left and right."""

    __slots__ = ['key', 'left', 'right']

    def __init__(self, key):
        self.key = key
        # Both child links start out empty.
        self.left = self.right = None

Performance Considerations

  1. Access Pattern Optimization
class CacheAwareSplayTree(SplayTree):
    """Splay tree with a bounded membership cache in front of ``search``.

    The cache remembers keys that a previous search found. A cache hit
    answers without touching the tree, so no splay is performed for it.
    Once ``cache_size`` keys are stored, further hits in the tree are no
    longer cached.
    """

    def __init__(self, cache_size=1000):
        super().__init__()
        self.cache = {}
        self.cache_size = cache_size

    def search(self, key):
        """Return True when ``key`` is in the tree, consulting the cache first."""
        # Check cache first
        if key in self.cache:
            return True

        # Coerce so callers always get a real bool.
        found = bool(super().search(key))
        if found and len(self.cache) < self.cache_size:
            self.cache[key] = True

        return found

    def delete(self, key):
        """Remove ``key`` from the tree and drop any cached entry for it.

        Without the invalidation, ``search`` would keep reporting a deleted
        key as present.
        """
        self.cache.pop(key, None)
        super().delete(key)
  2. Rotation Cost Management
class AdaptiveSplayTree(SplayTree):
    """Splay tree that stops splaying on search when rotations get too frequent.

    ``rotation_count`` counts every rotation performed by the tree and
    ``access_count`` counts calls to ``search``. After the first 100
    searches, a search only splays while the rotations-per-search ratio is
    below 0.5. Otherwise it falls back to a read-only BST lookup.
    """

    def __init__(self):
        super().__init__()
        self.rotation_count = 0
        self.access_count = 0

    def _right_rotate(self, x):
        """Right-rotate at ``x`` and record the rotation."""
        self.rotation_count += 1
        return super()._right_rotate(x)

    def _left_rotate(self, x):
        """Left-rotate at ``x`` and record the rotation."""
        self.rotation_count += 1
        return super()._left_rotate(x)

    def _should_splay(self):
        """Decide whether the current search may restructure the tree."""
        if self.access_count < 100:
            return True
        return (self.rotation_count / self.access_count) < 0.5

    def _bst_search(self, node, key):
        """Plain iterative BST lookup from ``node``; never modifies the tree."""
        while node is not None:
            if key == node.key:
                return True
            node = node.left if key < node.key else node.right
        return False

    def search(self, key):
        """Return True when ``key`` is present, splaying only when allowed."""
        self.access_count += 1
        if not self._should_splay():
            # Perform simple BST search instead
            return self._bst_search(self.root, key)
        return bool(super().search(key))

Common Pitfalls

  1. Improper Splaying
# Wrong
def search_wrong(self, key):
    # Simple BST search; this method itself never triggers a splay
    return self._find(key) is not None

# Correct
def search_correct(self, key):
    """Splay ``key`` to the root, store the new root, and report presence.

    Always returns a real ``bool``. An empty tree yields ``False`` instead
    of leaking the ``None`` root through an ``and`` expression.
    """
    self.root = self._splay(self.root, key)
    return self.root is not None and self.root.key == key
  2. Missing Root Update
# Wrong
def insert_wrong(self, key):
    if self.root:
        # The splay result is thrown away here.
        self._splay(self.root, key)  # Didn't update root
        return
    self.root = Node(key)

# Correct
def insert_correct(self, key):
    if self.root:
        # Keep whatever the splay returns as the new root.
        self.root = self._splay(self.root, key)
        return
    self.root = Node(key)
  3. Incorrect Rotation
# Wrong
def _rotate_wrong(self, x):
    """Deliberately broken right rotation, kept as a counter-example.

    Nothing is returned, so a caller never learns which node became the
    top of the rotated subtree.
    """
    if x.left:
        # The right-hand side is evaluated first, but the targets are then
        # assigned left to right. ``x.left`` is overwritten before the
        # second target ``x.left.right`` is resolved, so that store lands on
        # the new ``x.left`` (or fails when it is None) and the old left
        # child is no longer referenced from ``x``.
        x.left, x.left.right = x.left.right, x  # Lost references

# Correct
def _rotate_correct(self, x):
    """Right-rotate at ``x`` and return the node that takes its place."""
    # Capture the left child first so both stores below use a stable name.
    pivot = x.left
    x.left, pivot.right = pivot.right, x
    return pivot