Problem Description

Two elements of a binary search tree (BST) are swapped by mistake.
Recover the tree without changing its structure.

Example 1

Input: [1,3,null,null,2]
  1
 /
3
 \
  2
Output: [3,1,null,null,2]
  3
 /
1
 \
  2

Example 2

Input: [3,1,4,null,null,2]

  3
 / \
1   4
   /
  2

Output: [2,1,4,null,null,3]

  2
 / \
1   4
   /
  3

Solution

First, we need to know how to check whether a tree is a valid BST: traverse it in inorder and check that the node values come out in non-decreasing order. This already hints at the solution. To correct the error, we have to locate it, and to locate it we simply do an inorder traversal: whenever a node's value is less than the previous node's value, there must be an error at that point. Once we know where the error is, we can correct it by swapping the values of the misplaced nodes.
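
A minimal sketch of that validity check in C++, assuming a simple `TreeNode` struct with `val`, `left` and `right` (the helper names `inorderIsSorted` and `isValidBST` are illustrative). The traversal carries the previously visited node and reports a violation as soon as a value is smaller than its predecessor.

```cpp
// Minimal binary tree node, assumed for this sketch.
struct TreeNode {
    int val;
    TreeNode* left;
    TreeNode* right;
    explicit TreeNode(int v, TreeNode* l = nullptr, TreeNode* r = nullptr)
        : val(v), left(l), right(r) {}
};

// Hypothetical helper: inorder traversal that tracks the previously visited
// node and fails as soon as a value is smaller than its predecessor.
bool inorderIsSorted(const TreeNode* node, const TreeNode*& prev) {
    if (node == nullptr) return true;
    if (!inorderIsSorted(node->left, prev)) return false;
    if (prev != nullptr && node->val < prev->val) return false;  // a "drop"
    prev = node;
    return inorderIsSorted(node->right, prev);
}

// The tree is a valid BST when its inorder sequence is non-decreasing.
bool isValidBST(const TreeNode* root) {
    const TreeNode* prev = nullptr;
    return inorderIsSorted(root, prev);
}
```

For Example 1, the input tree yields the inorder sequence 3, 2, 1 and is rejected, while the output tree yields 1, 2, 3 and is accepted.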

There is still a subtlety, though. According to the problem description, the error is caused by swapping the values of two nodes, and there are two possible cases. In the first, the two nodes are adjacent in inorder, so the inorder traversal yields something like:

[1,2,3,4,6,5,7,8]

where the numbers 5 and 6 are swapped. In the second, the two nodes are not adjacent:

[1,7,3,4,5,6,2,8]

where the numbers 2 and 7 are swapped.
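
The two cases can be told apart by counting the "drops" in the inorder sequence. A small sketch on the sequences above, using a hypothetical `findDrops` helper that returns the index of the later element of each drop:

```cpp
#include <cstddef>
#include <vector>

// Hypothetical helper: returns every index i with seq[i] < seq[i - 1],
// i.e. every drop that an inorder scan would detect.
std::vector<std::size_t> findDrops(const std::vector<int>& seq) {
    std::vector<std::size_t> drops;
    for (std::size_t i = 1; i < seq.size(); ++i) {
        if (seq[i] < seq[i - 1]) drops.push_back(i);
    }
    return drops;
}
```

For the first sequence it returns `{5}` (one drop, between 6 and 5); for the second it returns `{2, 6}` (two drops, between 7 and 3 and between 6 and 2).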

The strategy is then as follows. After a complete inorder traversal, if we detected the error only once, we are in case 1 (the two nodes are adjacent), and we simply swap the current node with the previous node at that point. If we detected the error twice, we are in case 2. Note that at the first detection, the node to swap is actually the previous node: when we compare 7 and 3, 7 is the misplaced one. At the second detection, the current node is the one to swap: that is the number 2. Both cases can be handled the same way by remembering the previous node at the first detection and the current node at the last detection; in case 1 these are exactly the two adjacent nodes. This completes our strategy.
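
To see the strategy in isolation, here is a hedged sketch on a plain `std::vector<int>` standing in for the inorder sequence (the function `recoverSequence` is a hypothetical illustration, not part of the tree solution). It keeps the previous element at the first drop and the current element at the last drop, then swaps them.

```cpp
#include <cstddef>
#include <utility>
#include <vector>

// Hypothetical array analogue: repairs a sequence that was sorted
// before exactly two of its elements were swapped.
void recoverSequence(std::vector<int>& seq) {
    std::size_t first = 0, second = 0;
    bool found = false;
    for (std::size_t i = 1; i < seq.size(); ++i) {
        if (seq[i] < seq[i - 1]) {
            if (!found) {
                first = i - 1;  // first drop: the previous element is misplaced
                found = true;
            }
            second = i;         // latest drop: the current element is misplaced
        }
    }
    if (found) std::swap(seq[first], seq[second]);
}
```

With one drop (case 1) `first` and `second` are neighbours; with two drops (case 2) they are the outer ends of the two drops.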

C++ code

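#include <utility>  // std::swap

// Minimal binary tree node (assumed to match the judge's definition).
struct TreeNode {
    int val;
    TreeNode *left;
    TreeNode *right;
    TreeNode(int x) : val(x), left(nullptr), right(nullptr) {}
};

class Solution {
public:
    void recoverTree(TreeNode* root) {
        // y: previous node at the first drop; x: current node at the latest drop
        TreeNode *x=nullptr, *y=nullptr, *prev=nullptr;
        helper(root, &prev, &x, &y);
        if (x != nullptr && y != nullptr) std::swap(x->val, y->val);
    }
    
    // Inorder traversal that records where the values stop being non-decreasing
    void helper(TreeNode* root, TreeNode** prev, TreeNode** x, TreeNode** y) {
        if(root==nullptr) return;
        helper(root->left, prev, x, y);
        if((*prev)!=nullptr && root->val < (*prev)->val){
            (*x)=root;                    // current node is a candidate
            if((*y)==nullptr) *y=*prev;   // first drop: remember the previous node
            else return;                  // second drop: both nodes are known
        }
        (*prev) = root;
        helper(root->right, prev, x, y);
    }
};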
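
For comparison, a sketch of the same first/last-drop idea with an explicit stack instead of recursion. The function name `recoverTreeIterative` is hypothetical, and `TreeNode` is redefined so the snippet stands alone; keeping the state in local variables also avoids the pointer-to-pointer parameters.

```cpp
#include <stack>
#include <utility>

// Minimal binary tree node, assumed for this sketch.
struct TreeNode {
    int val;
    TreeNode* left;
    TreeNode* right;
    explicit TreeNode(int v, TreeNode* l = nullptr, TreeNode* r = nullptr)
        : val(v), left(l), right(r) {}
};

// Iterative inorder traversal with the same drop-detection rule as recoverTree.
void recoverTreeIterative(TreeNode* root) {
    std::stack<TreeNode*> pending;
    TreeNode* prev = nullptr;
    TreeNode* first = nullptr;   // previous node at the first drop
    TreeNode* second = nullptr;  // current node at the latest drop
    TreeNode* node = root;
    while (node != nullptr || !pending.empty()) {
        while (node != nullptr) {  // descend as far left as possible
            pending.push(node);
            node = node->left;
        }
        node = pending.top();
        pending.pop();
        if (prev != nullptr && node->val < prev->val) {
            if (first == nullptr) first = prev;
            second = node;
        }
        prev = node;
        node = node->right;        // then visit the right subtree
    }
    if (first != nullptr && second != nullptr) std::swap(first->val, second->val);
}
```

On Example 2 the inorder sequence is 1, 3, 2, 4, so there is a single drop and the values 3 and 2 are swapped back. Like the recursive version, it uses O(h) extra space for a tree of height h.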