Why does this code to check if a binary tree is balanced take time O(n log n) when it recomputes depths multiple times?

Shisui :

This code is meant to check if a binary tree is balanced (balanced being defined as a tree such that the heights of the two subtrees of any node never differ by more than one).

I understand the N part of the runtime O(NlogN). The N is because every node in the tree is visited at least once.

int getHeight(TreeNode root){
    if(root==null) return -1; //Base case
    return Math.max(getHeight(root.left), getHeight(root.right))+1; 
}

boolean isBalanced(TreeNode root){
    if(root == null) return true; //Base case

    int heightDiff = getHeight(root.left) - getHeight(root.right);

    if(Math.abs(heightDiff) > 1){
        return false;
    } else{ //Recurse
        return isBalanced(root.left) && isBalanced(root.right);
    }
}

What I don't understand is the logN part of the runtime O(NlogN). The code will trace every possible path from a node to the bottom of the tree. So shouldn't the runtime be more like O(N·2^N) or something? How does one come to the conclusion, step by step, that the runtime is O(NlogN)?

templatetypedef :

I agree with you that the runtime of this code is not necessarily O(n log n). However, I don't believe that it will always trace out every path from a node to the bottom of the tree. For example, consider this tree:

                  *
                 /
                *
               /
              *

Here, computing the depths of the left and right subtrees will indeed visit every node once. However, because an imbalance is found between the left and right subtrees, the recursion stops without recursively exploring the left subtree. In other words, finding an example where the recursion has to do a lot of work is going to require some creativity.
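To make the early exit concrete, here is a minimal sketch that runs the code from the question on a three-node left spine and counts how many nodes getHeight touches. The class name SpineDemo, the visits counter, and the leftSpine helper are assumptions added for illustration; they are not part of the original code.

```java
class SpineDemo {
    static class TreeNode {
        TreeNode left, right;
    }

    // Hypothetical instrumentation: number of non-null nodes touched by getHeight.
    static int visits = 0;

    static int getHeight(TreeNode root) {
        if (root == null) return -1; // Base case
        visits++;
        return Math.max(getHeight(root.left), getHeight(root.right)) + 1;
    }

    static boolean isBalanced(TreeNode root) {
        if (root == null) return true; // Base case
        int heightDiff = getHeight(root.left) - getHeight(root.right);
        if (Math.abs(heightDiff) > 1) {
            return false; // Imbalance found: no further recursion
        }
        return isBalanced(root.left) && isBalanced(root.right);
    }

    // Builds a chain of `nodes` nodes linked only through left children.
    static TreeNode leftSpine(int nodes) {
        TreeNode root = null;
        for (int i = 0; i < nodes; i++) {
            TreeNode parent = new TreeNode();
            parent.left = root;
            root = parent;
        }
        return root;
    }

    public static void main(String[] args) {
        visits = 0;
        boolean balanced = isBalanced(leftSpine(3));
        // The two nodes below the root are each scanned once, then the check stops.
        System.out.println(balanced + " after " + visits + " getHeight visits");
        // prints "false after 2 getHeight visits"
    }
}
```

The root's height difference is already 2, so isBalanced returns before it ever recurses into the left child.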

You are correct that the baseline check for the height difference will take time Θ(n) because every node must be scanned. The concern with this code is that it might rescan nodes many, many times as it recomputes the height differences during the recursion. If we want this function to run for a really long time - not necessarily as long as possible, but for a long time - we'd want to make it so that

  • the left and right subtrees have roughly the same height, so that the recursion proceeds to the left subtree, but
  • the tree is extremely imbalanced, placing most of the nodes into the left subtree.

One way to do this is to create trees where the right subtree is just a long spine that happens to have the same height as the left subtree, but with way fewer nodes. Here's one possible sequence of trees that has this property:

                              *
                             / \
                *           *   *
               / \         / \   \
      *       *   *       *   *   *
     / \     / \   \     / \   \   \
*   *   *   *   *   *   *   *   *   *

Mechanically, each tree is formed by taking the previous tree and putting a rightward spine on top of it. Operationally, these trees are defined recursively as follows:

  • An order-0 tree is a single node.
  • An order-(k+1) tree is a node whose left child is an order-k tree and whose right child is a linked list of height k.

Notice that the number of nodes in an order-k tree is Θ(k^2). You can see this by noticing that the trees have a nice triangular shape, where each layer has one more node in it than the previous one. Sums of the form 1 + 2 + 3 + ... + k work out to Θ(k^2), and while we can be more precise than this, there really isn't a need to do so.
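The recursive definition and the node count above can be sketched in Java as follows. The names OrderTreeDemo, rightSpine, orderTree and countNodes are hypothetical helpers introduced here, and "height" is measured in edges, matching getHeight in the question (a single node has height 0).

```java
class OrderTreeDemo {
    static class TreeNode {
        TreeNode left, right;
    }

    // Right-leaning linked list with the given height, i.e. height + 1 nodes.
    static TreeNode rightSpine(int height) {
        TreeNode root = null;
        for (int i = 0; i <= height; i++) {
            TreeNode parent = new TreeNode();
            parent.right = root;
            root = parent;
        }
        return root;
    }

    // Builds an order-k tree bottom-up from an order-0 tree (a single node).
    static TreeNode orderTree(int k) {
        TreeNode root = new TreeNode();
        for (int order = 1; order <= k; order++) {
            TreeNode parent = new TreeNode();
            parent.left = root;                   // the previous, order-(order-1) tree
            parent.right = rightSpine(order - 1); // linked list of height order-1
            root = parent;
        }
        return root;
    }

    static int countNodes(TreeNode root) {
        return root == null ? 0 : 1 + countNodes(root.left) + countNodes(root.right);
    }

    public static void main(String[] args) {
        for (int k = 0; k <= 4; k++) {
            System.out.println("order " + k + ": " + countNodes(orderTree(k)) + " nodes");
        }
        // prints 1, 3, 6, 10, 15 nodes for orders 0 through 4
    }
}
```

The first four counts match the four trees drawn above, and the triangular-number pattern is where the Θ(k^2) node count comes from.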

Now, what happens if we fire off this recursion on the root of any one of these trees? Well, the recursion will begin by computing the heights of the left and right subtrees, which will report that they have the same height as one another. It will then recursively explore the left subtree to see whether it's balanced. After doing some (large) amount of work, it'll find that the left subtree is not balanced (this holds once the left subtree has order 3 or more), at which point the recursion won't branch to the right subtree. In other words, the amount of work done on an order-k tree is lower-bounded by

  • W(0) = 1 (there's a single node visited once), and
  • W(k+1) = W(k) + Θ(k^2).

To see where the recurrence for W(k+1) comes from, notice that we begin by scanning every node in the tree, and there are Θ(k^2) nodes to scan; we then recursively apply the procedure to the left subtree. Expanding this recurrence, we see that in an order-k tree, the total work done is

W(k) = Θ(k^2) + W(k-1)

= Θ(k^2 + (k-1)^2) + W(k-2)

= Θ(k^2 + (k-1)^2 + (k-2)^2) + W(k-3)

...

= Θ(k^2 + (k-1)^2 + ... + 2^2 + 1^2)

= Θ(k^3).

This last step follows from the fact that the sum of the first k squares, which is k(k+1)(2k+1)/6, works out to Θ(k^3).

To finish things off, we have one more step. We've shown that order-k trees require Θ(k^3) total work to process with this recursive algorithm. However, we'd like a runtime bound in terms of n, the total number of nodes in the tree, not k, the order of the tree. Using the fact that the number of nodes in a tree of order k is Θ(k^2), we see that a tree with n nodes has order k = Θ(n^(1/2)). Plugging this in, we see that for arbitrarily large n, we can make the total work done equal to Θ((n^(1/2))^3) = Θ(n^(3/2)), which exceeds the proposed O(n log n) bound you mentioned. I'm not sure whether this is the worst-case input for this algorithm, but it's certainly not a good one.
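As a rough empirical check of this bound, the sketch below runs the question's code on order-k trees of growing size and counts getHeight visits. The class name OrderTreeWork, the visits counter, and the helper methods are assumptions made for this illustration, and the particular values of k are arbitrary.

```java
class OrderTreeWork {
    static class TreeNode {
        TreeNode left, right;
    }

    // Hypothetical instrumentation: number of non-null nodes touched by getHeight.
    static long visits;

    static int getHeight(TreeNode root) {
        if (root == null) return -1;
        visits++;
        return Math.max(getHeight(root.left), getHeight(root.right)) + 1;
    }

    static boolean isBalanced(TreeNode root) {
        if (root == null) return true;
        int heightDiff = getHeight(root.left) - getHeight(root.right);
        if (Math.abs(heightDiff) > 1) return false;
        return isBalanced(root.left) && isBalanced(root.right);
    }

    // Right-leaning linked list with the given height, i.e. height + 1 nodes.
    static TreeNode rightSpine(int height) {
        TreeNode root = null;
        for (int i = 0; i <= height; i++) {
            TreeNode parent = new TreeNode();
            parent.right = root;
            root = parent;
        }
        return root;
    }

    // Order-k tree: left child is the previous tree, right child is a spine.
    static TreeNode orderTree(int k) {
        TreeNode root = new TreeNode();
        for (int order = 1; order <= k; order++) {
            TreeNode parent = new TreeNode();
            parent.left = root;
            parent.right = rightSpine(order - 1);
            root = parent;
        }
        return root;
    }

    static int countNodes(TreeNode root) {
        return root == null ? 0 : 1 + countNodes(root.left) + countNodes(root.right);
    }

    // Runs isBalanced once and reports how many node visits getHeight made.
    static long countVisits(TreeNode root) {
        visits = 0;
        isBalanced(root);
        return visits;
    }

    public static void main(String[] args) {
        for (int k = 25; k <= 400; k *= 2) {
            TreeNode tree = orderTree(k);
            int n = countNodes(tree);
            long work = countVisits(tree);
            System.out.printf("k=%d n=%d visits=%d visits/n^1.5=%.3f%n",
                    k, n, work, work / Math.pow(n, 1.5));
        }
    }
}
```

If the analysis above is right, the last column should level off at a constant as k grows, whereas visits divided by n log n would keep increasing.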

So yes, you are correct - the runtime is not O(n log n) in general. However, it is the case that if the tree is perfectly balanced, the runtime is indeed O(n log n). To see why, notice that if the tree is perfectly balanced, each recursive call will

  • do O(n) work scanning each node in the tree, then
  • make two recursive calls on smaller trees, each of which is approximately half as large as the previous one.

That gives the recurrence T(n) = 2T(n / 2) + O(n), which solves to O(n log n). But that's just one specific case, not the general case.
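The perfectly balanced case can be checked the same way. This sketch builds perfect binary trees and counts getHeight visits made by the question's code; PerfectTreeWork, perfectTree and the visits counter are hypothetical names added for illustration.

```java
class PerfectTreeWork {
    static class TreeNode {
        TreeNode left, right;
    }

    // Hypothetical instrumentation: number of non-null nodes touched by getHeight.
    static long visits;

    static int getHeight(TreeNode root) {
        if (root == null) return -1;
        visits++;
        return Math.max(getHeight(root.left), getHeight(root.right)) + 1;
    }

    static boolean isBalanced(TreeNode root) {
        if (root == null) return true;
        int heightDiff = getHeight(root.left) - getHeight(root.right);
        if (Math.abs(heightDiff) > 1) return false;
        return isBalanced(root.left) && isBalanced(root.right);
    }

    // Perfect binary tree of the given height (in edges): 2^(height+1) - 1 nodes.
    static TreeNode perfectTree(int height) {
        if (height < 0) return null;
        TreeNode root = new TreeNode();
        root.left = perfectTree(height - 1);
        root.right = perfectTree(height - 1);
        return root;
    }

    // Runs isBalanced once and reports how many node visits getHeight made.
    static long countVisits(TreeNode root) {
        visits = 0;
        isBalanced(root);
        return visits;
    }

    public static void main(String[] args) {
        for (int h = 4; h <= 16; h += 4) {
            long n = (1L << (h + 1)) - 1;
            long work = countVisits(perfectTree(h));
            double nLogN = n * (Math.log(n) / Math.log(2));
            System.out.printf("height=%d n=%d visits=%d visits/(n log2 n)=%.3f%n",
                    h, n, work, work / nLogN);
        }
    }
}
```

Each node at depth d is scanned once for every proper ancestor, so the visit count is the sum of d·2^d over all levels, and the ratio in the last column should stay bounded by a constant.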

A concluding note - with a minor modification, this code can be made to run in time O(n) in all cases. Instead of recomputing the height of each subtree, make an initial pass over the tree and annotate each node with its height (either by setting some internal field equal to the height or by having an auxiliary HashMap mapping each node to its height). This can be done in time O(n). From there, recursively walking the tree and checking whether the left and right subtrees have heights that differ by at most one requires O(1) work per node across n total nodes, for a total runtime of O(n).
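One way this two-pass idea might look in Java is sketched below, using the auxiliary HashMap variant. The names LinearBalanceCheck, fillHeights, heightOf and check are assumptions for illustration, and the O(1) lookups are expected-time, as usual for hashing.

```java
import java.util.HashMap;
import java.util.Map;

class LinearBalanceCheck {
    static class TreeNode {
        TreeNode left, right;
    }

    // First pass: store the height of every node in the map (one visit per node).
    static int fillHeights(TreeNode root, Map<TreeNode, Integer> heights) {
        if (root == null) return -1;
        int height = Math.max(fillHeights(root.left, heights),
                              fillHeights(root.right, heights)) + 1;
        heights.put(root, height);
        return height;
    }

    // Height lookup that treats a missing child as height -1.
    static int heightOf(TreeNode node, Map<TreeNode, Integer> heights) {
        return node == null ? -1 : heights.get(node);
    }

    // Second pass: same check as the original code, but with constant-time lookups.
    static boolean check(TreeNode root, Map<TreeNode, Integer> heights) {
        if (root == null) return true;
        int heightDiff = heightOf(root.left, heights) - heightOf(root.right, heights);
        if (Math.abs(heightDiff) > 1) return false;
        return check(root.left, heights) && check(root.right, heights);
    }

    static boolean isBalanced(TreeNode root) {
        Map<TreeNode, Integer> heights = new HashMap<>();
        fillHeights(root, heights);
        return check(root, heights);
    }

    public static void main(String[] args) {
        TreeNode balanced = new TreeNode();
        balanced.left = new TreeNode();
        balanced.right = new TreeNode();

        TreeNode spine = new TreeNode();
        spine.left = new TreeNode();
        spine.left.left = new TreeNode();

        System.out.println(isBalanced(balanced)); // true
        System.out.println(isBalanced(spine));    // false
    }
}
```

Since TreeNode does not override equals or hashCode, the map keys on object identity, which is what we want when annotating individual nodes.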

Hope this helps!

