Binary Search Tree Iterators Inspired by Call Stacks

Binary Search Tree Iterators Inspired by Call Stacks video (41 minutes) (Spring 2021)

I recommend "Binary tree traversal: Preorder, Inorder, Postorder" by mycodeschool.

Example Binary Search Tree

This is the tree that we will traverse with an iterator. All example nodes, call stacks, etc., come from this tree.

            e
           / \
          b   f
         / \
        a   d
           /
          c
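To experiment with the walkthrough below, here is a minimal sketch that builds the example tree in C++. The `Node` layout (a `char value` plus `left`/`right` pointers) and the helper names `buildExampleTree` and `countNodes` are assumptions for illustration; the BST class used in these notes may look different.

```cpp
#include <cstddef>

// Hypothetical node layout for the examples: a char value plus pointers
// to the left (smaller) and right (bigger) children.
struct Node {
  char value;
  Node* left = nullptr;
  Node* right = nullptr;
};

// Builds the example tree rooted at e. The nodes have static storage,
// so they live for the whole program and never need to be freed.
Node* buildExampleTree() {
  static Node a{'a'}, c{'c'}, f{'f'};
  static Node d{'d', &c, nullptr};  // d has only a left child, c
  static Node b{'b', &a, &d};       // b's children are a and d
  static Node e{'e', &b, &f};       // e is the root
  return &e;
}

// Counts the nodes in a subtree, handy for sanity-checking the shape.
std::size_t countNodes(const Node* n) {
  return n == nullptr ? 0 : 1 + countNodes(n->left) + countNodes(n->right);
}
```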

In-Order Traversal

The goal of the iterator we are making is to traverse a sorted collection in order. As the iterator is incremented (++it), bigger or "higher value" elements are visited. E.g., with characters, the iterator will visit a, b, c, ... and finally z.

For a BST of Nodes, each having a Left (smaller) subtree and Right (bigger) subtree, In-Order Traversal can be thought of as:

  1. Processing the left subtree
  2. Processing the node
  3. Processing the right subtree

(This algorithm applies recursively to all nodes and subtrees)
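As a sketch, assuming a hypothetical `Node` with a `char value` and `left`/`right` pointers, the three steps translate directly into a recursive function. This version appends each value to a string instead of printing it, which makes the visiting order easy to check; the name `inOrder` is illustrative.

```cpp
#include <string>

// Hypothetical node layout: a char value plus child pointers.
struct Node {
  char value;
  Node* left = nullptr;
  Node* right = nullptr;
};

// In-order traversal: left subtree, then the node, then the right subtree.
void inOrder(const Node* node, std::string& out) {
  if (node == nullptr) return;  // empty subtree: nothing to process
  inOrder(node->left, out);     // 1. process the left subtree
  out += node->value;           // 2. process the node
  inOrder(node->right, out);    // 3. process the right subtree
}

// The example tree from the notes, with static storage (nothing to free).
Node* buildExampleTree() {
  static Node a{'a'}, c{'c'}, f{'f'};
  static Node d{'d', &c, nullptr};
  static Node b{'b', &a, &d};
  static Node e{'e', &b, &f};
  return &e;
}
```

For the example tree, the collected string is `abcdef`, the same order that print() produces.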

We have already seen an example of In-Order Traversal of a BST, in the print() method that has been defined for BST.

    void print() {
      if (left != nullptr) {
        left->print();
      }
      cout << value << endl;
      if (right != nullptr) {
        right->print();
      }
    }

For the example tree, the output of this print() method would be:

a
b
c
d
e
f

Recursive Call Stacks

The above print() method is recursive. It starts off by "drilling down" in the BST to the left-most node, which, having no left child, prints its own value and returns. This results in the printing of "a".

The recursive call stack is just that: a stack. Each time a function recursively calls itself, a new stack frame is pushed, and several things may be stored in it:

  1. An instruction pointer, telling the recursive call to "return to this certain place in instruction memory when finished"
  2. Any local variables associated with the recursive function
  3. In our example, since the print() method belongs to the Node struct, an implicit "this" pointer is stored. (For our example, "this" points to a Node.)
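One way to make these frames concrete is to run print()'s logic with an explicit stack of frames instead of real recursion. In this sketch, each hypothetical `Frame` holds the "this" pointer plus an integer resume point that stands in for the saved instruction pointer; the `Node` layout and all names are assumptions, and values are appended to a string instead of written to cout.

```cpp
#include <string>
#include <vector>

// Hypothetical node layout: a char value plus child pointers.
struct Node {
  char value;
  Node* left = nullptr;
  Node* right = nullptr;
};

// One emulated frame of print(): the implicit "this" pointer plus a
// resume point standing in for the saved instruction pointer.
struct Frame {
  const Node* self;
  int resumeAt;  // 0 = just called, 1 = after left->print(), 2 = after right->print()
};

// Emulates print() with an explicit stack of frames.
std::string printWithExplicitFrames(const Node* root) {
  std::string out;
  std::vector<Frame> stack;
  if (root != nullptr) stack.push_back({root, 0});
  while (!stack.empty()) {
    Frame& top = stack.back();  // never used after a push_back below
    if (top.resumeAt == 0) {
      top.resumeAt = 1;  // "return address": resume after the left call
      if (top.self->left != nullptr) stack.push_back({top.self->left, 0});
    } else if (top.resumeAt == 1) {
      out += top.self->value;  // cout << value
      top.resumeAt = 2;        // resume after the right call
      if (top.self->right != nullptr) stack.push_back({top.self->right, 0});
    } else {
      stack.pop_back();  // print() returns: its frame is popped
    }
  }
  return out;
}

// The example tree from the notes, with static storage (nothing to free).
Node* buildExampleTree() {
  static Node a{'a'}, c{'c'}, f{'f'};
  static Node d{'d', &c, nullptr};
  static Node b{'b', &a, &d};
  static Node e{'e', &b, &f};
  return &e;
}
```

Note that this faithful emulation keeps b's frame on the stack (waiting at resume point 2) while d's subtree runs, just as a real call stack would.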

At the point where the recursive print() call reaches the a node, what would the stack of "this" pointers be?

            e
           / \
          b   f
         / \
        a   d
           /
          c

For our example tree, it would be as follows, written "upside down": the bottom of the stack (the first call) is listed first, and the top (the most recent call) is listed last. This is only a small part of each recursive stack frame: we are only considering the "this" pointers, i.e., the Nodes we are visiting. Each frame may also hold an instruction pointer and other local variables.

e
b
a

Then, the a node is processed (printed), and popped off the stack, which is now:

e
b

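At this point, the print() call scoped to b prints its value ("b") and then visits its own right subtree. Strictly speaking, b's frame stays on the real call stack until right->print() returns, but b has no work left to do, so we can treat it as popped; this simplification is exactly what the iterator will exploit. The root of b's right subtree is d. d is added to the stack, but not printed yet: we need to visit d's left subtree first!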

e
d

Once print() is called, scoped to d, it will visit its own left subtree, adding c to the stack.

e
d
c

Node c has no children, so its value is printed and c is popped off the stack.

e
d

Node d's left subtree is now done, so d's value is printed. Since d has no right subtree, d is then popped off the stack.

e

We are now back at e, the root! e's left subtree has been processed, so it is time to print the node's value ('e'), pop e off the stack, and then visit the right subtree. Once e is popped off, the stack is empty.

(empty stack!)

Time to visit the right subtree. f is e's right child, so f is pushed onto the stack.

f

Node f has no children, so its value is printed and f is popped off the stack. The print() function is done, and all of its recursive calls have returned.

Extracting an algorithm out of the print() recursive call stack

We have just seen how the print() function traverses a BST in order. Mirroring the recursive call stack, we can traverse the tree starting at the root while maintaining our own stack of pointers, representing "nodes still to be processed". As long as we have a rule for when to push node pointers onto this stack and when to pop them off as they are processed, we will be in good shape.

Things to keep in mind

Here are a few things to keep in mind as we build this algorithm.

  1. When processing a BST, you (generally) start at the root.
  2. Our goal is to iterate through a BST in sorted order, or, "in-order traversal"
  3. Starting from the root, we will perform a "findMin()" type of operation, to get the begin() iterator - the smallest element in the tree. We will keep going left until we hit a nullptr.
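  4. The end() iterator can simply be built from a nullptr: walking left from a nullptr pushes nothing, so its stack is empty (the default Iterator() is the same). This represents a position that is "past the biggest element in the tree"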
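A sketch of the begin() and end() positions described above, assuming a hypothetical char-valued `Node`: walking left from the root and pushing every node leaves the minimum on top of the stack, while walking from a nullptr pushes nothing, and that empty stack plays the role of end(). The function names `beginStack` and `findMin` are illustrative, not part of the notes' code.

```cpp
#include <vector>

// Hypothetical node layout: a char value plus child pointers.
struct Node {
  char value;
  Node* left = nullptr;
  Node* right = nullptr;
};

// begin(): keep going left from the root until nullptr, pushing each node.
// The back of the vector (the top of the stack) is the smallest element.
std::vector<const Node*> beginStack(const Node* root) {
  std::vector<const Node*> stack;
  for (const Node* cur = root; cur != nullptr; cur = cur->left) {
    stack.push_back(cur);
  }
  return stack;
}

// A findMin()-style helper: the left-most node, or nullptr for an empty tree.
const Node* findMin(const Node* root) {
  if (root == nullptr) return nullptr;
  while (root->left != nullptr) root = root->left;
  return root;
}

// The example tree from the notes, with static storage (nothing to free).
Node* buildExampleTree() {
  static Node a{'a'}, c{'c'}, f{'f'};
  static Node d{'d', &c, nullptr};
  static Node b{'b', &a, &d};
  static Node e{'e', &b, &f};
  return &e;
}
```

For the example tree, the begin() stack is e, b, a (bottom to top), the same stack of "this" pointers that print() has when it reaches a.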

The general algorithm for a subtree (applies recursively)

  1. Push a node, the subtree root, onto the stack
  2. Travel left as far as possible, pushing visited nodes onto the stack
  3. When a node has no left child, process the node, and pop it from the stack
  4. Visit the node's right child (which we saved before the pop in step 3!)
  5. As of step 4, we are now at the root node of a new subtree. Go back to step 1.
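The five steps can be sketched as a plain loop over an explicit stack. The helper names (`pushLeftSpine`, `stackKeys`, `inOrderExplicitStack`) and the char-valued `Node` are assumptions for illustration. The optional `snapshots` argument records the stack, bottom to top, after each node is processed, so the states can be compared with the walkthrough of print() above.

```cpp
#include <string>
#include <vector>

// Hypothetical node layout: a char value plus child pointers.
struct Node {
  char value;
  Node* left = nullptr;
  Node* right = nullptr;
};

// Steps 1-2: push a subtree root, then keep going left, pushing each node.
void pushLeftSpine(const Node* cur, std::vector<const Node*>& stack) {
  for (; cur != nullptr; cur = cur->left) {
    stack.push_back(cur);
  }
}

// Renders the stack bottom to top as a string, e.g. "edc".
std::string stackKeys(const std::vector<const Node*>& stack) {
  std::string s;
  for (const Node* n : stack) s += n->value;
  return s;
}

// In-order traversal with an explicit stack. Returns the visiting order and,
// if snapshots is non-null, records the stack after each node is processed.
std::string inOrderExplicitStack(const Node* root,
                                 std::vector<std::string>* snapshots) {
  std::string visited;
  std::vector<const Node*> stack;
  pushLeftSpine(root, stack);         // steps 1-2 for the whole tree
  while (!stack.empty()) {
    const Node* top = stack.back();
    visited += top->value;            // step 3: process the node,
    const Node* right = top->right;   // save its right child,
    stack.pop_back();                 // and pop it
    pushLeftSpine(right, stack);      // steps 4-5: the right child is a new subtree root
    if (snapshots != nullptr) snapshots->push_back(stackKeys(stack));
  }
  return visited;
}

// The example tree from the notes, with static storage (nothing to free).
Node* buildExampleTree() {
  static Node a{'a'}, c{'c'}, f{'f'};
  static Node d{'d', &c, nullptr};
  static Node b{'b', &a, &d};
  static Node e{'e', &b, &f};
  return &e;
}
```

For the example tree, the visiting order is `abcdef` and the snapshots are `eb`, `edc`, `ed`, `e`, `f`, and finally an empty stack. The walkthrough's intermediate "e d" state (before c is pushed) is folded into one step here, because steps 4-5 push d and its left child c together.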

Code example

The stack is maintained as a vector.

In the Iterator(Node* root) constructor, you can see that it travels left, adding nodes to the stack, until it hits a nullptr. This sets the iterator up in a begin() position.

In the operator++ method, you can see that we take the right child of the current node (the top of the stack), remove the current node by popping it off the stack, and then, starting from that right child, travel as far left as possible in the right subtree we are now located in, pushing nodes onto the stack along the way.

Also note that we capture the current node's right child BEFORE popping the node off the stack. This way, we don't lose our place.

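  // Nested inside the BST class, which supplies Node (key, value, left,
  // right) and the Key/Value types. Assumes `using namespace std;`.
  class Iterator {
    // Nodes still to be processed; the back of the vector is the current node.
    vector<Node*> stack;
  public:
    // constructor: the default one (empty stack) can serve as end()
    Iterator() {}
    // begin() position: walk left from the root, pushing each node
    Iterator(Node* root) {
      for (Node* cur = root; cur != nullptr; cur = cur->left) {
        stack.push_back(cur);
      }
    }
    // compare (!=): two iterators differ when their stacks differ
    bool operator!=(const Iterator& rhs) const {
      return stack != rhs.stack;
    }
    // dereference (*): the node on top of the stack
    // (precondition: not at end(), i.e. the stack is non-empty)
    pair<Key, Value> operator*() const {
      Node* top = stack.back();
      return pair<Key, Value>(top->key, top->value);
    }
    // increment (++): save the right child, pop the current node, then
    // walk left from that right child, pushing each node
    // (precondition: not at end())
    Iterator& operator++() {
      Node* cur = stack.back()->right;
      stack.pop_back();
      for (; cur != nullptr; cur = cur->left) {
        stack.push_back(cur);
      }
      return *this;
    }

    // This is for demo purposes only, to show how Iterator works.
    void print_stack_keys() const {
      for (const Node* cur : stack) {
        cout << cur->key << ' ';
      }
      cout << endl;
    }
  };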
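The Iterator class is meant to live inside a BST class that provides Node, Key and Value. Here is a minimal sketch of such a host class, assuming a plain unbalanced BST; the `insert()`, `begin()` and `end()` members and the destructor are illustrative assumptions, while the nested Iterator follows the code above. Because the class provides begin(), end(), `!=`, `*` and `++`, it also works with a range-based for loop.

```cpp
#include <utility>
#include <vector>

// Hypothetical minimal BST hosting the stack-based Iterator.
template <typename Key, typename Value>
class BST {
  struct Node {
    Key key;
    Value value;
    Node* left = nullptr;
    Node* right = nullptr;
  };
  Node* root = nullptr;

  // Frees a subtree (post-order, so children go before their parent).
  static void destroy(Node* n) {
    if (n == nullptr) return;
    destroy(n->left);
    destroy(n->right);
    delete n;
  }

 public:
  BST() = default;
  BST(const BST&) = delete;  // no deep copy in this sketch
  BST& operator=(const BST&) = delete;
  ~BST() { destroy(root); }

  // Plain unbalanced insert; an existing key gets its value replaced.
  void insert(const Key& key, const Value& value) {
    Node** link = &root;
    while (*link != nullptr) {
      if (key < (*link)->key) link = &(*link)->left;
      else if ((*link)->key < key) link = &(*link)->right;
      else { (*link)->value = value; return; }
    }
    *link = new Node{key, value};
  }

  // The same stack-based Iterator as in the notes.
  class Iterator {
    std::vector<Node*> stack;  // nodes still to be processed
   public:
    Iterator() {}
    Iterator(Node* root) {
      for (Node* cur = root; cur != nullptr; cur = cur->left) stack.push_back(cur);
    }
    bool operator!=(const Iterator& rhs) const { return stack != rhs.stack; }
    std::pair<Key, Value> operator*() const {
      Node* top = stack.back();
      return std::pair<Key, Value>(top->key, top->value);
    }
    Iterator& operator++() {
      Node* cur = stack.back()->right;  // save the right child first
      stack.pop_back();
      for (; cur != nullptr; cur = cur->left) stack.push_back(cur);
      return *this;
    }
  };

  Iterator begin() const { return Iterator(root); }
  Iterator end() const { return Iterator(); }  // empty stack: past the end
};
```

Inserting e, b, f, a, d, c in that order reproduces the example tree, and iterating with `for (auto kv : tree)` then visits the keys a through f in order.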

Thanks

Thanks to Brian Foster for writing up these notes based on the video lectures.